import numpy as np
import pandas as pd
import pickle
import torch as pt

from torch.autograd import Variable
from torch import sigmoid
from torch.nn.functional import binary_cross_entropy

from Prediction.VoteResult import SimpleVoteResult


class Predictor:
    """Predict vote percentages for not-yet-reported locations.

    Each location's voting history is embedded into a low-dimensional
    feature space via SVD; a vote-count-weighted logistic regression is
    then fitted (by plain gradient descent) on the locations whose
    results are already observed and used to predict the rest.
    """

    def __init__(self, hist, info, dims, gamma=0.01, n_iter=30000):
        """Build the SVD feature matrix and store fit hyper-parameters.

        Parameters
        ----------
        hist : pd.DataFrame
            Historical vote percentages, one row per location. Rows with
            any missing value are dropped.
        info : pd.DataFrame
            Per-location metadata sharing ``hist``'s index; must contain
            a ``VoteCount`` column used as regression weights.
        dims : int
            Number of SVD components kept as features.
        gamma : float
            Initial gradient-descent learning rate.
        n_iter : int
            Maximum number of gradient-descent iterations.
        """
        hist = hist.dropna(axis=0)
        info = info.loc[hist.index]
        U, s, _ = np.linalg.svd(hist.values.astype(float))

        self.info = info
        # Scaled left singular vectors are the per-location features.
        self.Xs = pd.DataFrame(U[:, :dims] * s[:dims], index=hist.index)
        self.Xs[dims] = 1  # adding bias column
        self.ks = pd.DataFrame(info.VoteCount.values, index=info.index)
        self.params = dict()
        self.d = dims + 1  # feature dims + bias
        self.gamma = gamma
        self.n_iter = n_iter

    def init_params(self):
        """Return a fresh zero weight vector that tracks gradients."""
        return pt.zeros(self.d, requires_grad=True)

    def separate(self, res: "list[SimpleVoteResult]") -> "tuple[tuple[list[str], np.ndarray], list[str]]":
        """Split known locations into observed and not-yet-seen groups.

        Results whose identifier is not in the feature index are ignored.

        Returns
        -------
        ((observed_ids, percentages), not_seen_ids)
            ``observed_ids`` in input order with their reported
            percentages as a numpy array; ``not_seen_ids`` in feature
            index order (deterministic, unlike the arbitrary set order
            a ``list(set(...))`` would give).
        """
        observed = []
        percentages = []
        for result in res:
            if result.identifier in self.Xs.index:
                observed.append(result.identifier)
                percentages.append(result.percentage)

        observed_set = set(observed)
        not_seen = [loc for loc in self.Xs.index if loc not in observed_set]
        return (observed, np.array(percentages)), not_seen

    def predict(self, res, tol=1e-9):
        """Fit weighted logistic regression on observed results and predict the rest.

        Parameters
        ----------
        res : list[SimpleVoteResult]
            Observed results (identifier, percentage, count).
        tol : float or None
            Stop early when the loss change falls below this; ``None``
            disables early stopping.

        Returns
        -------
        (predictions, losses)
            ``predictions`` maps each unseen location to its predicted
            percentage; ``losses`` is the training-loss trace.
        """
        weights = self.init_params()

        (observed, percentages), not_seen = self.separate(res)

        X_train = pt.from_numpy(self.Xs.loc[observed].values.astype(np.float32))
        y_train = pt.from_numpy(percentages.astype(np.float32))
        X_pred = pt.from_numpy(self.Xs.loc[not_seen].values.astype(np.float32))
        k_train = pt.from_numpy(
            self.ks.loc[observed].values.flatten().astype(np.float32))

        # Loop-invariant weight normalizer, hoisted out of the descent loop.
        k_total = float(k_train.sum())
        # Local learning rate: halving on divergence must not permanently
        # mutate self.gamma across predict() calls.
        lr = self.gamma

        losses = list()
        prev_loss = float('inf')
        for _ in range(self.n_iter):
            y_hat = sigmoid(X_train @ weights)
            loss = (binary_cross_entropy(
                y_hat, y_train, weight=k_train, reduction='sum')
                    / k_total)
            loss.backward()
            with pt.no_grad():
                weights -= lr * weights.grad
            weights.grad.zero_()
            new_loss = loss.item()
            if prev_loss < new_loss:
                lr /= 2  # diverging step: damp the learning rate
            if tol is not None and abs(prev_loss - new_loss) < tol:
                break
            losses.append(new_loss)
            prev_loss = new_loss

        w_pred = weights.detach()
        y_pred = sigmoid(X_pred @ w_pred).tolist()

        return dict(zip(not_seen, y_pred)), losses

    def count_results(self, obs: "list[SimpleVoteResult]", uobs: "list[SimpleVoteResult]", pred):
        """Overall yes-share combining observed and predicted results.

        Observed results contribute their reported percentage; unobserved
        results contribute the prediction in ``pred`` (entries missing
        from ``pred`` are skipped). Returns 0.0 when no votes are counted
        (guards against ZeroDivisionError on empty input).
        """
        yes_votes, total_votes = 0, 0
        # unobserved results: weight the predicted percentage by vote count
        for res in uobs:
            identifier = res.identifier
            num_valid = res.count
            if identifier in pred:
                yes_votes += pred[identifier] * num_valid
                total_votes += num_valid
        # observed results: weight the reported percentage by vote count
        for res in obs:
            yes_votes += res.count * res.percentage
            total_votes += res.count
        if total_votes == 0:
            return 0.0
        return yes_votes / total_votes
