import json
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np

import pandas as pd
import os

from DataPreparation.Preparation import merge_dfs, calculate_percentages, calculate_counts
from Prediction.ComplexSubMatrix import ComplexSubmatrix
from Prediction.Submatrix import Predictor
from Prediction.VoteResult import SimpleVoteResult
from Testing.TestingUtils import get_samples_for_batching, simulate_vote_result_batching, simulate_partial_vote_result


def load_data(path: str, index_col: str = 'Okres'):
    """Load every second-round election CSV found under ``../../../{path}/raw``.

    Args:
        path: Directory name. It is combined into ``../../../{path}/raw`` and
            resolved relative to the current working directory, not this file.
        index_col: Column of each CSV used as the DataFrame index.

    Returns:
        A list of ``(file_name, calculate_counts(df), calculate_percentages(df))``
        tuples. There is one tuple per file whose name contains ``'_2.'``, ordered
        by file name.
    """
    def is_second_round(name: str):
        # Second-round files are recognised purely by the '_2.' marker in the name.
        return '_2.' in name

    base_path = Path(f'../../../{path}/raw')
    # os.listdir returns entries in arbitrary order. Sort them so the order of the
    # returned list (and everything derived from it, e.g. merge order) is
    # reproducible across machines.
    names = sorted(filter(is_second_round, os.listdir(base_path)))
    # Zeros are replaced by 1, presumably to avoid division by zero in the
    # percentage/count helpers. TODO confirm.
    dfs = [(name, pd.read_csv(base_path.joinpath(name), index_col=index_col).replace(0, 1)) for name in names]
    return [(name, calculate_counts(df), calculate_percentages(df)) for (name, df) in dfs]


def get_vote_results(df: pd.DataFrame):
    """Split a per-district frame into counted and not-yet-counted vote results.

    The first column of ``df`` holds each district's result. A value of -1 marks
    a district that has not been counted yet. ``df`` must also contain a
    ``'VoteCount'`` column. Index values become the result identifiers.

    Returns:
        ``(observed_results, unseen_results)``: two lists of SimpleVoteResult.
        Uncounted districts are reported with ``percentage=0``.
    """
    result_col = df.columns[0]
    is_pending = df[result_col] == -1
    counted, pending = df[~is_pending], df[is_pending]

    observed_results = [
        SimpleVoteResult(identifier=idx, percentage=row[result_col], count=row['VoteCount'])
        for idx, row in counted.iterrows()
    ]
    unseen_results = [
        SimpleVoteResult(identifier=idx, percentage=0, count=row['VoteCount'])
        for idx, row in pending.iterrows()
    ]
    return observed_results, unseen_results


def perform_simulation(test: pd.DataFrame, info: pd.DataFrame, predictor: Predictor, number_of_batches=15):
    """Replay the count of ``test`` batch by batch and predict after each batch.

    Args:
        test: Per-district results of the election being simulated.
        info: Extra per-district columns (must include ``'VoteCount'``) that are
            concatenated column-wise onto each batch.
        predictor: Model exposing ``predict`` and ``count_results``.
        number_of_batches: Passed to ``simulate_vote_result_batching``.

    Returns:
        A list with one dict per batch (the first batch is skipped). Each dict
        has keys ``"prediction"``, ``"current_result"`` and ``"count_percentage"``.
    """
    batches = simulate_vote_result_batching(test, number_of_batches)
    results = []
    # NOTE(review): batch 0 is skipped, presumably because it holds no counted
    # votes yet. Verify against simulate_vote_result_batching.
    for partial in batches[1:]:
        merged = pd.concat([partial, info], axis=1)
        observed, unobserved = get_vote_results(merged)

        y_pred, _ = predictor.predict(observed)

        prediction = predictor.count_results(observed, unobserved, y_pred)
        # Result using only what has been counted so far.
        current = predictor.count_results(observed, [], y_pred)

        *_, counted_share = calculate_statistics(observed, unobserved)
        results.append({
            "prediction": prediction,
            "current_result": current,
            "count_percentage": counted_share
        })

    return results


def calculate_statistics(obs, uobs):
    """Summarise how much of the vote has been counted.

    Args:
        obs: Results already counted. Each item must expose a ``count`` attribute.
        uobs: Results not yet counted, with the same requirement.

    Returns:
        A tuple ``(observed_count, total_count, count_percentage)``.
        ``count_percentage`` is ``observed_count / total_count * 100``, or
        ``0.0`` when ``total_count`` is zero.
    """
    observed_count = sum(result.count for result in obs)
    total_count = observed_count + sum(result.count for result in uobs)
    if total_count == 0:
        # Nothing to count (empty inputs or all-zero counts): avoid ZeroDivisionError.
        return observed_count, total_count, 0.0
    count_percentage = observed_count / total_count * 100
    return observed_count, total_count, count_percentage


def slovak():
    """Simulate and plot the counting of each second-round election.

    Every historic/test split from ``get_samples_for_batching`` except the first
    gets a ``Predictor`` fitted on the history. The test election is then
    replayed in 25 batches, and predictions are plotted against the running count.
    """
    data = load_data('data')

    percentage_frames = [percentages for (_, _, percentages) in data]
    names = [file_name.split('.')[0] for (file_name, _, _) in data]

    merged = merge_dfs(percentage_frames, names)
    for hist, test in get_samples_for_batching(merged)[1:]:
        election = test.columns[0]
        # Second tuple element of the matching load_data entry (calculate_counts output).
        counts = data[names.index(election)][1]
        predictor = Predictor(hist, counts, len(hist.columns))
        preds = perform_simulation(test, counts, predictor, 25)
        plot_results_with_diff(
            [p['prediction'] for p in preds],
            [p['current_result'] for p in preds],
            [p['count_percentage'] for p in preds],
            election,
        )


def plot_results(ypred, count, x, name="", plot=None):
    """Plot predicted and counted results against the share of counted votes.

    Args:
        ypred: Predicted result after each counting step.
        count: Result of the votes counted so far at each step. Its last element
            is drawn as the horizontal "Final result" line.
        x: Percentage of counted votes at each step (x-axis values).
        name: Text inserted into the chart title.
        plot: Axes to draw on. When None, a new figure/axes is created.

    Returns:
        The Axes that was drawn on.
    """
    if plot is None:
        _, plot = plt.subplots()
    plot.set_xticks(np.arange(0, 101, 5))
    plot.plot(x, ypred, linewidth=4, label='Predicted result')
    plot.plot(x, count, linewidth=4, label='Counted result', alpha=0.5)
    plot.axhline(count[-1], linewidth=2, color='black', label='Final result', alpha=0.5)
    plot.set_xlabel('Percentage of counted votes [%]')
    # NOTE(review): the 0.5 threshold below suggests values are fractions, while
    # the label says [%]. Confirm the unit.
    plot.set_ylabel('Results [%]')
    plot.legend(loc='lower right')
    plot.set_title(f'Prediction vs. Count for {name}')
    y_min, y_max = plot.get_ylim()
    # Shade the visible range above 0.5, but only when such a range exists. The
    # old test `y_min < 0.5 or 0.5 < y_max` was true for almost any range, so a
    # chart lying entirely below 0.5 got a green span from y_max up to 0.5.
    if y_max > 0.5:
        plot.axhspan(max(0.5, y_min), y_max, facecolor='green', alpha=0.1)
    plot.grid()
    return plot


def plot_difference(y_pred, count, x, name="", plot=None):
    """Plot the absolute gap between each prediction and the final counted result.

    Args:
        y_pred: Predicted result after each counting step.
        count: Running counted result. Its last element is taken as the final result.
        x: Percentage of counted votes at each step (x-axis values).
        name: Accepted for signature symmetry with ``plot_results``; not used here.
        plot: Axes to draw on. When None, a new figure/axes is created.

    Returns:
        The Axes that was drawn on. Gaps are scaled by 100.
    """
    ax = plt.subplots()[1] if plot is None else plot
    ax.set_xticks(np.arange(0, 101, 5))

    target = count[-1]
    errors = [abs(pred - target) * 100 for pred in y_pred]
    ax.plot(x, errors, linewidth=4, label='Difference')

    ax.set_xlabel('Percentage of counted votes [%]')
    ax.set_ylabel('Difference [%]')
    ax.set_title('Absolute error')
    ax.grid()
    return ax


def plot_results_with_diff(ypred, count, x, name=""):
    """Show the prediction-vs-count chart and its absolute-error chart side by side.

    Both charts are drawn on one 10x5 figure with two columns. The figure is
    displayed with ``plt.show()``.
    """
    _, axes = plt.subplots(1, 2, figsize=(10, 5))
    left, right = axes
    plot_results(ypred, count, x, name, left)
    plot_difference(ypred, count, x, name, right)
    plt.tight_layout()
    plt.show()


def mix_test():
    """Evaluate ComplexSubmatrix once per historic/test split on a partial count.

    For each pair from ``get_samples_for_batching``:
    - the test election's count data are passed through
      ``simulate_partial_vote_result(info, 0.05, 0.15)``;
    - a ``ComplexSubmatrix`` is built on the history;
    - the prediction, the current result and the share of counted votes are recorded.

    Returns:
        A list with one dict per pair, with keys ``"prediction"``,
        ``"current_result"`` and ``"count_percentage"``. The list is returned
        so callers can actually inspect the results.
    """
    data = load_data('data')

    pure_data = [percentage_df for (_, _, percentage_df) in data]
    names = [name.split('.')[0] for (name, _, _) in data]

    df = merge_dfs(pure_data, names)
    pairs = get_samples_for_batching(df)

    predictions = []

    for (hist, test) in pairs:
        # Second tuple element of the matching load_data entry (calculate_counts output).
        info = data[names.index(test.columns[0])][1]
        simulated_info = simulate_partial_vote_result(info, 0.05, 0.15)
        pr = ComplexSubmatrix(hist, info, len(hist.columns))

        t = pd.concat([test, simulated_info], axis=1)
        # Only the counted results are used; the not-yet-counted list is ignored.
        obs, _ = get_vote_results(t)
        (seen, _), nseen = pr.separate(obs)
        yp, _ = pr.predict(obs)
        # NOTE(review): uobs_r is drawn from obs (ids in nseen), not from the
        # not-yet-counted results. Confirm this is intended.
        obs_r = [o for o in obs if o.identifier in seen]
        uobs_r = [o for o in obs if o.identifier in nseen]
        res = pr.count_results(obs_r, uobs_r, yp)
        current_res = pr.count_results(obs_r, [], yp)

        _, _, count_percentage = calculate_statistics(obs_r, uobs_r)
        predictions.append({
            "prediction": res,
            "current_result": current_res,
            "count_percentage": count_percentage
        })

    return predictions


if __name__ == '__main__':
    # Default entry point: replay the second-round elections and plot the
    # predictions. Call mix_test() instead to evaluate ComplexSubmatrix.
    slovak()
