import json
import os
from pathlib import Path

import pandas as pd

from Prediction.VoteResult import ComplexVoteResult


### UNIVERSAL PREPARATION
def get_okres_okrsok_mapping(df):
    """Group the second part of each index label of ``df`` by its first part.

    Every index label must be a string of the form ``"<okres>-<okrsok>"`` containing
    exactly one ``'-'``; otherwise unpacking the split raises ``ValueError``.

    :param df: DataFrame whose index holds the composite string IDs.
    :return: dict mapping each ``okres`` prefix to the list of its ``okrsok`` parts,
        both as strings, in the order they appear in the index.
    """
    grouped = {}
    for label in df.index.array:
        okres, okrsok = label.split('-')
        grouped.setdefault(okres, []).append(okrsok)
    return grouped


def align_both_keys(old_keys, new_keys):
    """Pair each new okrsok key with the old key at the same position within its okres.

    For every okres in ``new_keys`` the old list for the same okres is padded with
    ``'-1'`` (meaning "no old counterpart") up to the length of the new list and the two
    are zipped positionally. Okreses missing from ``old_keys`` are paired entirely with
    ``'-1'``. If the old list is longer than the new one, the surplus old keys are
    dropped (``zip`` truncates). Neither input mapping is modified.

    :param old_keys: dict okres -> list of okrsok strings of the frame to re-key.
    :param new_keys: dict okres -> list of okrsok strings of the reference frame.
    :return: dict okres -> list of ``(new_okrsok, old_okrsok)`` tuples.
    """
    zipped = {}
    for key, keys_2018 in new_keys.items():
        # Copy the old list so the padding below never mutates the caller's mapping.
        keys_2013 = list(old_keys.get(key, []))
        keys_2013.extend(['-1'] * (len(keys_2018) - len(keys_2013)))
        zipped[key] = list(zip(keys_2018, keys_2013))
    return zipped


def reset_df(df_with_wrong_keys, zipped_keys):
    """Rebuild ``df_with_wrong_keys`` under a new index using old->new key pairs.

    For each okres and each ``(new, old)`` pair in ``zipped_keys`` the output gets a
    row labelled ``"<okres>-<int(new)>"``. Its values are copied from the input row
    ``"<okres>-<int(old)>"``, or are all zeros when ``old == '-1'`` (no counterpart).
    A missing old row raises ``KeyError`` from ``.loc``.

    :param df_with_wrong_keys: source frame indexed by ``"<okres>-<okrsok>"`` strings.
    :param zipped_keys: output of ``align_both_keys``: dict okres -> list of
        ``(new_okrsok, old_okrsok)`` string pairs.
    :return: new DataFrame with the same columns and an index named ``'ID'``.
    """
    columns = list(df_with_wrong_keys.columns)
    zero_row = [0] * len(columns)
    new_index = []
    rows = []
    # Collect rows first and build the frame once: enlarging a DataFrame with
    # ``.loc[new_label] = ...`` copies it on each call, which is quadratic overall.
    for key, pairs in zipped_keys.items():
        for new, old in pairs:
            new_index.append(f"{key}-{int(new)}")
            if old == '-1':
                rows.append(zero_row)
            else:
                rows.append(df_with_wrong_keys.loc[f"{key}-{int(old)}"].tolist())
    return pd.DataFrame(rows, columns=columns, index=pd.Index(new_index, name='ID'))


### CZ DATA PREPARATION
def clear_xslx_cz(base_path, pth):
    """Pivot a raw ';'-separated Czech results CSV into one row per ``OBEC-OKRSEK`` ID.

    Reads ``base_path / pth``, keeps only the OBEC, OKRSEK, ESTRANA and POC_HLASU
    columns, builds an ``ID`` of the form ``"<OBEC>-<OKRSEK>"`` and pivots so that each
    ESTRANA value becomes an integer ``"<ESTRANA>-Votes"`` column holding the summed
    POC_HLASU (0 where absent). The result is written next to the input with ``-f``
    appended to the file stem, e.g. ``eu_volby_2014.csv`` -> ``eu_volby_2014-f.csv``.

    Despite the name, the input is read as CSV, not as an XLSX workbook.

    :param base_path: ``pathlib.Path`` of the directory holding the input file.
    :param pth: input file name (str), relative to ``base_path``.
    """
    to_open = base_path.joinpath(pth)
    df = pd.read_csv(to_open, sep=';')

    allowed_column = ['OBEC', 'OKRSEK', 'ESTRANA', 'POC_HLASU']
    # Keep only the needed columns (in file order); copy to get an independent frame.
    df = df[[column for column in df.columns if column in allowed_column]].copy()

    df['OBEC'] = df['OBEC'].astype(str)
    df['OKRSEK'] = df['OKRSEK'].astype(str)
    df['ID'] = df['OBEC'] + '-' + df['OKRSEK']

    # Vote counts must be summed: pivot_table's default aggfunc is 'mean', which would
    # average (and, after the int cast, truncate) duplicate ID/party rows.
    pivot_df = df.pivot_table(index='ID', columns='ESTRANA', values='POC_HLASU',
                              aggfunc='sum', fill_value=0)
    pivot_df.columns = [f"{col}-Votes" for col in pivot_df.columns]
    pivot_df = pivot_df.astype(int)

    # Build the output name with pathlib: splitting on '.' breaks for names with
    # several dots (or none).
    source = Path(pth)
    out_name = source.with_name(f"{source.stem}-f{source.suffix}")
    pivot_df.to_csv(base_path.joinpath(out_name))


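'''
This function relies on the order of the indexes in the dataframes. It reads the original
dataframes from the `fixed` subdirectory of the given directory, aligns their keys to those
of the last dataframe read, and writes the results to the `fixed_aligned` subdirectory
(existing output files are not overwritten).
'''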


def align_df_indexes(base_base_path):
    """Align the CSVs in ``<base_base_path>/fixed`` to the last one and save them.

    Every file in ``fixed`` is read with ``ID`` as index. The files are processed in
    sorted-name order, and the last one is the reference whose keys the others are
    aligned to (via ``align_df_indexes_mem``). Each aligned frame is written under the
    same name to ``<base_base_path>/fixed_aligned``, which is created if missing.
    Files that already exist there are left untouched.

    :param base_base_path: ``pathlib.Path`` of the directory holding ``fixed``.
    """
    base_path = base_base_path.joinpath('fixed')
    results_path = base_base_path.joinpath('fixed_aligned')
    # os.listdir order is arbitrary; sort so the choice of reference frame (the last
    # one) is deterministic between runs and machines.
    names = sorted(os.listdir(base_path))

    dfs = [pd.read_csv(base_path.joinpath(name), index_col='ID') for name in names]
    # NOTE(review): align_df_indexes_mem returns one frame fewer than it receives
    # (the reference is not included), so zip drops the last name and the reference
    # file is never written to fixed_aligned - confirm this is intended.
    fixed_dfs = list(zip(names, align_df_indexes_mem(dfs)))

    results_path.mkdir(parents=True, exist_ok=True)
    print(len(fixed_dfs))
    for name, df in fixed_dfs:
        aligned_name = results_path.joinpath(name)
        print(aligned_name)
        if aligned_name.exists():
            continue
        df.to_csv(aligned_name)


def calculate_percentages(df_):
    """Return the share of the first column in each row's total, plus matching columns.

    A ``Votes`` column is set to the row-wise sum of all columns of ``df_``, then
    overwritten with the first column divided by that sum. The result is a fraction,
    not multiplied by 100. Only columns whose name contains ``'Votes'`` or ``'Okres'``
    are kept. ``df_`` itself is not modified.
    """
    with_total = df_.assign(Votes=df_.sum(axis=1))
    # Position 0 is read after 'Votes' was set, exactly as the share is defined above.
    with_share = with_total.assign(Votes=with_total.iloc[:, 0] / with_total['Votes'])
    wanted = with_share.columns.str.contains('Votes|Okres')
    return with_share.loc[:, wanted]


def calculate_counts(df):
    """Return the row-wise total of ``df`` as a ``VoteCount`` column.

    All columns of ``df`` are summed per row. The output keeps only columns whose name
    contains ``'VoteCount'`` or ``'Okres'``. ``df`` itself is not modified.
    """
    counted = df.assign(VoteCount=df.sum(axis=1))
    wanted = counted.columns.str.contains('VoteCount|Okres')
    return counted.loc[:, wanted]


def align_df_indexes_mem(dfs: list[pd.DataFrame]) -> list[pd.DataFrame]:
    """Re-key every frame except the last so its index follows the last frame's keys.

    The last frame in ``dfs`` is the reference. For every other frame, its okres/okrsok
    keys are paired positionally with the reference's (``align_both_keys``) and the
    frame is rebuilt under the reference keys (``reset_df``).

    The returned list has ``len(dfs) - 1`` frames: the reference itself is not
    included. An empty ``dfs`` raises ``IndexError``.
    """
    key_maps = [get_okres_okrsok_mapping(frame) for frame in dfs]
    reference = key_maps[-1]

    pairings = [align_both_keys(key_map, reference) for key_map in key_maps[:-1]]
    return [reset_df(frame, pairing) for frame, pairing in zip(dfs, pairings)]


def merge_dfs(dfs: list[pd.DataFrame], names: list[str]) -> pd.DataFrame:
    """Place ``dfs`` side by side and relabel the resulting columns with ``names``.

    Rows are aligned on the index (``pd.concat`` along axis 1). ``names`` must supply
    exactly one label per resulting column.
    """
    return pd.concat(dfs, axis=1).set_axis(names, axis=1)


def prepare_df_cz_eu(year, folder_path, party_columns):
    """Load one year's aligned EU results and convert them to per-row percentages.

    Reads ``folder_path / f"eu_volby_{year}.csv"`` indexed by ``ID``. The
    ``"<party>-Votes"`` column of each party in ``party_columns`` is kept, and every
    remaining column is summed into ``"Other-<year>"``. Each row is then scaled so its
    values sum to 100. A row whose total is 0 yields NaN (0/0).

    :return: DataFrame with the party columns (in ``party_columns`` order) followed by
        the ``Other-<year>`` column.
    """
    raw = pd.read_csv(folder_path.joinpath(f"eu_volby_{year}.csv"), index_col='ID')
    tracked = [f"{party_id}-Votes" for party_id in party_columns]
    other_label = f"Other-{year}"

    shares = raw.loc[:, tracked].copy()
    shares[other_label] = raw.drop(columns=tracked).sum(axis=1)
    row_totals = shares.sum(axis=1)
    shares = shares.div(row_totals, axis=0) * 100

    return shares[tracked + [other_label]]


def cz_preparation(base_path=Path('../data_cz/eurovolby/fixed_aligned'),
                   results_60_path=Path('DataPreparation/cz_euro_2024/results-60.json')):
    """Build the 2014/2019 CZ EU-election percentage frame and the 2024 vote counts.

    Historical results are loaded per year with ``prepare_df_cz_eu`` for fixed party
    ids (2014: 16, 32, 7; 2019: 30, 27, 26) and merged column-wise. The JSON file at
    ``results_60_path`` must hold an array of objects with a ``key`` and an optional
    ``count`` (default 0). They become a ``VoteCount`` frame indexed by ``ID``, keeping
    only the first record per key.

    Both default paths are relative to the current working directory.
    NOTE(review): one default starts with '../' and the other does not, which suggests
    different working directories - confirm the intended CWD.

    :param base_path: directory holding the aligned ``eu_volby_<year>.csv`` files.
    :param results_60_path: JSON file with the 2024 per-district records.
    :return: tuple ``(df, info)`` of the merged percentages and the vote counts.
    """
    years = [(2014, [16, 32, 7]), (2019, [30, 27, 26])]

    year_frames = [prepare_df_cz_eu(year, base_path, parties) for year, parties in years]
    column_names = [column for frame in year_frames for column in frame.columns]
    df = merge_dfs(year_frames, column_names)

    # JSON is UTF-8 by spec (RFC 8259); don't depend on the platform's locale encoding.
    with open(results_60_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    info = pd.DataFrame({
        "ID": [record['key'] for record in data],
        "VoteCount": [record.get('count', 0) for record in data],
    }).set_index("ID")
    # Keep only the first record for each ID.
    info = info[~info.index.duplicated(keep='first')]

    return df, info


def load_cz_2024_election_results(file_path, party_ids=(14, 23, 17)):
    """Load 2024 per-district results from a JSON file as ``ComplexVoteResult`` objects.

    The file must hold an array of objects. For each object the percentages are
    ``party_<id>_percent`` for every id in ``party_ids`` (in that order), followed by
    ``other_percent``. Missing values default to 0. ``key`` becomes the identifier and
    ``count`` (default 0) the count.

    :param file_path: path of the JSON file.
    :param party_ids: party ids whose percentages are extracted; the default keeps the
        original (14, 23, 17) selection.
    :return: list of ``ComplexVoteResult``, one per record, in file order.
    """
    percent_keys = [f"party_{party_id}_percent" for party_id in party_ids]
    percent_keys.append('other_percent')

    # JSON is UTF-8 by spec (RFC 8259); don't depend on the platform's locale encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    return [
        ComplexVoteResult(
            identifier=record['key'],
            percentages=[record.get(key, 0) for key in percent_keys],
            count=record.get('count', 0),
        )
        for record in data
    ]
