import pandas as pd
import numpy as np
import torch as pt
from torch.autograd import Variable
from torch.nn.functional import cross_entropy, softmax, kl_div, log_softmax, sigmoid, mse_loss, binary_cross_entropy

from Prediction.VoteResult import SimpleVoteResult, ComplexVoteResult

# Module-level default for the number of SVD components kept per class
# (MultiPredictor keeps dims * num_classes components in total).
# Set dims to 4 to consider historical info from one year with 4 columns for four parties.
# NOTE(review): MultiPredictor does not read this name — its __init__ takes its
# own `dims` parameter, which shadows it. Presumably callers pass this value in
# explicitly; confirm.
dims = 4


class MultiPredictor:
    """Predict per-location class distributions from partially observed results.

    Every location with complete history is embedded with a truncated SVD of
    ``hist``. A linear-softmax model on those embeddings is fitted to the
    locations whose results are already known. That model then predicts the
    class distribution for the remaining locations.
    """

    def __init__(self, hist, info, dims, num_classes, gamma=0.05, n_iter=30000):
        """Build the SVD feature matrix for every location with complete history.

        :param hist: DataFrame of historical results, one row per location.
            Rows containing any NaN are dropped.
        :param info: DataFrame whose index covers ``hist``'s index and which has
            a ``VoteCount`` column.
        :param dims: Number of SVD components to keep per class; the model uses
            ``dims * num_classes`` components in total.
        :param num_classes: Number of classes the model predicts directly.
        :param gamma: Initial Adam learning rate used by :meth:`predict`.
        :param n_iter: Maximum optimisation steps per :meth:`predict` call.
        :raises ValueError: if ``hist`` (after dropping NaN rows) cannot supply
            ``dims * num_classes`` SVD components.
        """
        hist = hist.dropna(axis=0)
        info = info.loc[hist.index]
        n_components = dims * num_classes
        # Without this check the slicing below silently yields fewer columns
        # than self.d and predict() later fails with a shape mismatch.
        if n_components > min(hist.shape):
            raise ValueError(
                f"need {n_components} SVD components (dims * num_classes) but "
                f"hist has shape {hist.shape} after dropping NaN rows"
            )
        # Thin SVD: only the leading columns of U are used, so do not build the
        # full (n_locations x n_locations) U matrix.
        U, s, _ = np.linalg.svd(hist.values.astype(float), full_matrices=False)
        self.num_classes = num_classes  # Assuming 4 classes with the 5th being calculable
        self.info = info
        # Columns 0..n_components-1 hold the SVD scores; column n_components is
        # a constant bias feature.
        self.Xs = pd.DataFrame(U[:, :n_components] * s[:n_components], index=hist.index)
        self.Xs[n_components] = 1  # adding bias

        # Number of features (SVD components + bias); rows of the weight matrix.
        self.d = n_components + 1
        self.ks = pd.DataFrame(info.VoteCount.values, index=info.index)
        # year -> weight tensor fitted by predict(), used as a warm start.
        self.params = dict()

        self.gamma = gamma
        self.n_iter = n_iter

    def init_params(self, year):
        """Return the weights to start optimisation from for ``year``.

        Returns the tensor stored by an earlier :meth:`predict` call with the
        same ``year`` (warm start). Otherwise returns a fresh zero matrix of
        shape ``(self.d, self.num_classes)`` with gradients enabled.
        """
        if year in self.params:
            return self.params[year]
        return pt.zeros(self.d, self.num_classes, requires_grad=True)

    def separate(self, res: "list[ComplexVoteResult]") -> "tuple[tuple[list[str], np.ndarray], list[str]]":
        """Split locations into observed (having a result) and not yet observed.

        Results whose ``identifier`` is not a row of ``self.Xs`` are ignored.
        When an identifier appears more than once, only its first result is used.

        :returns: ``((observed_ids, percentages), not_seen_ids)``, where:

            - ``percentages`` stacks each observed result's ``percentages``
              row-wise.
            - ``not_seen_ids`` lists the remaining rows of ``self.Xs`` in
              index order.
        """
        known = set(self.Xs.index)
        seen = set()
        observed = []
        percentages = []
        for result in res:
            ident = result.identifier
            if ident in known and ident not in seen:
                seen.add(ident)
                observed.append(ident)
                percentages.append(result.percentages)

        # Follow the index order rather than set iteration order, so the result
        # does not depend on string hash randomisation between runs.
        not_seen = [ident for ident in self.Xs.index if ident not in seen]
        return (observed, np.array(percentages)), not_seen

    def predict(self, res, year=1, tol=1e-12):
        """Fit the weights for ``year`` on observed results and predict the rest.

        :param res: Iterable of results exposing ``identifier`` and ``percentages``.
        :param year: Key under which the fitted weights are cached; a later call
            with the same key warm-starts from them.
        :param tol: Stop early once the loss changes by less than this between
            steps; ``None`` disables early stopping.
        :returns: ``(predictions, losses)``:

            - ``predictions`` maps each unobserved location to its predicted
              class distribution (a list of floats).
            - ``losses`` holds the training loss from every optimisation step.
        :raises ValueError: if no result in ``res`` matches a known location.
        """
        weights = self.init_params(year)

        (observed, percentages), not_seen = self.separate(res)
        if not observed:
            raise ValueError("predict() needs at least one observed result for a known location")

        X_train = pt.from_numpy(self.Xs.loc[observed].values.astype(np.float32))
        y_train = pt.from_numpy(percentages.astype(np.float32))
        X_pred = pt.from_numpy(self.Xs.loc[not_seen].values.astype(np.float32))
        k_train = self.ks.loc[observed].values.flatten().astype(np.float32)
        # The loss normaliser (total votes of observed locations) is constant
        # across iterations, so compute it once.
        vote_total = np.sum(k_train)

        # Keep the (possibly decayed) learning rate local so self.gamma remains
        # the configured starting rate for later calls.
        lr = self.gamma
        optimizer = pt.optim.Adam([weights], lr=lr)

        losses = []
        prev_loss = 1e9
        for _ in range(self.n_iter):
            optimizer.zero_grad()
            y_hat = softmax(X_train @ weights, dim=1)
            loss = binary_cross_entropy(y_hat, y_train, reduction='sum') / vote_total
            loss.backward()
            optimizer.step()

            new_loss = loss.item()
            losses.append(new_loss)
            if prev_loss < new_loss:
                # The loss went up: halve the step size.
                lr /= 2
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

            if tol is not None and abs(prev_loss - new_loss) < tol:
                break
            prev_loss = new_loss

        self.params[year] = weights
        with pt.no_grad():
            y_pred = softmax(X_pred @ weights.detach(), dim=1).tolist()

        return dict(zip(not_seen, y_pred)), losses

    def count_results(self, obs: "list[ComplexVoteResult]", uobs: "list[ComplexVoteResult]", pred):
        """Aggregate observed and predicted results into overall class shares.

        Vote contributions per class ``i``:

        - An unobserved location with an entry in ``pred`` contributes
          ``pred[identifier][i] * count`` votes.
        - An unobserved location missing from ``pred`` is skipped entirely:
          neither its votes nor its ``count`` are included.
        - An observed location contributes ``percentages[i] * count`` votes.

        The extra class ``num_classes`` receives whatever remains of the total
        valid votes.

        :returns: ``{"Party 1": share, ..., f"Party {num_classes + 1}": share}``.
        :raises ZeroDivisionError: if the included locations have no valid votes.
        """
        n = self.num_classes
        total_votes = [0] * (n + 1)  # One additional for the calculable class
        total_valid_votes = 0

        for res in uobs:
            if res.identifier not in pred:
                continue
            predicted_distribution = pred[res.identifier]
            for i in range(n):
                total_votes[i] += predicted_distribution[i] * res.count
            total_valid_votes += res.count

        for res in obs:
            total_valid_votes += res.count
            for i in range(n):
                total_votes[i] += res.percentages[i] * res.count

        # The remaining class gets whatever is left of the valid votes.
        total_votes[n] = total_valid_votes - sum(total_votes)

        return {f"Party {i + 1}": total_votes[i] / total_valid_votes for i in range(n + 1)}
