from pathlib import Path

import pandas as pd
from matplotlib import pyplot as plt

from DataPreparation.Preparation import cz_preparation, load_cz_2024_election_results, align_df_indexes
from Prediction.SubmatrixMulti import MultiPredictor
from Prediction.VoteResult import ComplexVoteResult
from Testing.TestingUtils import accumulate_results


def EU_VOLBY_2024(rounds=60, results_dir='DataPreparation/cz_euro_2024'):
    """Plot predicted vs. observed Czech EU 2024 results over successive rounds.

    For each round ``n`` in ``1..rounds`` the observed results are loaded from
    ``<results_dir>/results-<n>.json`` and passed to a ``MultiPredictor``. The
    predictor's output is wrapped in ``ComplexVoteResult`` objects. Two
    accumulations are recorded per round with ``accumulate_results``:

    * the predictions combined with the observations;
    * the observations alone.

    Both series are plotted against the round number.

    Args:
        rounds: Number of result snapshots to process (files ``1..rounds``).
        results_dir: Directory containing the ``results-<n>.json`` snapshots.
    """
    hist, info = cz_preparation()

    dims = 2
    predictor = MultiPredictor(hist, info, dims, 4)

    # (column name, legend label stem, colour) for every plotted quantity.
    # NOTE(review): assumes accumulate_results returns its values in this
    # column order — TODO confirm against Testing.TestingUtils.
    series = [
        ('party_14_percent', 'Party 14', 'red'),
        ('party_23_percent', 'Party 23', 'blue'),
        ('party_17_percent', 'Party 17', 'green'),
        ('other_percent', 'Other', 'black'),
    ]
    columns = [column for column, _, _ in series]

    predicted_accumulated = []
    observed_accumulated = []

    for round_number in range(1, rounds + 1):
        observations = load_cz_2024_election_results(f'{results_dir}/results-{round_number}.json')
        # The second return value (losses) is not needed for the plot.
        predictions, _losses = predictor.predict(observations)
        predicted_results = [
            ComplexVoteResult(
                identifier=key,
                percentages=list(value),
                # Identifiers absent from `info` are given a vote count of 0.
                count=info.loc[key, 'VoteCount'] if key in info.index else 0,
            )
            for key, value in predictions.items()
        ]

        predicted_accumulated.append(accumulate_results(predicted_results + observations))
        observed_accumulated.append(accumulate_results(observations))

    # Index rows by the real round number (1..rounds) so the x axis matches
    # the snapshot files instead of starting at 0.
    round_index = pd.RangeIndex(1, rounds + 1, name='round')
    df_predicted = pd.DataFrame(predicted_accumulated, columns=columns, index=round_index)
    df_observed = pd.DataFrame(observed_accumulated, columns=columns, index=round_index)

    plt.figure(figsize=(18, 12))
    for column, label, colour in series:
        plt.plot(df_predicted[column], label=f'Predicted {label} %', linestyle='--', color=colour)
        plt.plot(df_observed[column], label=f'Observed {label} %', color=colour)

    plt.xlabel('Rounds')
    plt.ylabel('Percentage')
    plt.title('Comparison of Predicted and Observed Percentages Over Time')
    plt.legend()
    plt.grid(True)
    plt.show()


def main_2():
    """Run ``align_df_indexes`` on the ``../data_cz/eurovolby/`` directory."""
    eurovolby_dir = Path('../data_cz/eurovolby/')
    align_df_indexes(eurovolby_dir)


if __name__ == '__main__':
    # Script entry point: run the predicted-vs-observed comparison and show the plot.
    EU_VOLBY_2024()
