from pathlib import Path

import pandas as pd

from Prediction.SubmatrixMulti import MultiPredictor
from Prediction.VoteResult import ComplexVoteResult


def testing_multiparties():
    """Run a manual smoke test of ``MultiPredictor`` and print its output.

    Reads ``hist2.csv`` and ``hist_count2.csv`` from ``../testingData/``
    (a relative path, so it is resolved against the current working
    directory), builds a predictor from them, feeds it three hard-coded
    ``ComplexVoteResult`` observations and prints one line per entry of
    the returned predictions mapping.
    """
    hist_csv = 'hist2.csv'
    info_csv = 'hist_count2.csv'
    base_path = Path('../testingData/')

    # Load data into DataFrames, both indexed by their 'ID' column.
    hist = pd.read_csv(base_path / hist_csv, index_col='ID')
    info = pd.read_csv(base_path / info_csv, index_col='ID')

    # Hard-coded sample observations; each carries three percentages.
    observations = [
        ComplexVoteResult(identifier=101, percentages=[0.2, 0.5, 0.3], count=2000),
        ComplexVoteResult(identifier=102, percentages=[0.1, 0.8, 0.1], count=3000),
        ComplexVoteResult(identifier=108, percentages=[0.69, 0.1, 0.21], count=4000),
    ]

    dims = 2

    # NOTE(review): the trailing 3 presumably is the number of parties (it
    # matches the length of each ``percentages`` list) -- confirm against
    # MultiPredictor's signature.
    predictor = MultiPredictor(hist, info, dims, 3)

    # predict() returns a pair; the second element (losses) is not used here.
    predictions, _losses = predictor.predict(observations)
    print("Predictions:")
    for key, value in predictions.items():
        print(f"District {key}: {value}")


# Run the multi-party prediction smoke test only when executed as a script.
if __name__ == '__main__':
    testing_multiparties()
