import torch
import torch.nn.functional as F
from sklearn import metrics
from dgl import backend
import matplotlib.pyplot as plt
import numpy as np

def ntype_one_hot_init(graph):
    """Initialize node feature "f" as a one-hot encoding of each node's type.

    Every node of the i-th node type receives the same length-``len(graph.ntypes)``
    float32 one-hot vector with a 1 at position i.
    """
    type_count = len(graph.ntypes)
    for type_idx, ntype in enumerate(graph.ntypes):
        # one label per node of this type, all equal to the type's index
        type_labels = torch.full([graph.num_nodes(ntype=ntype)], fill_value=type_idx)
        one_hot = F.one_hot(type_labels, num_classes=type_count).to(torch.float32)
        graph.ndata["f"] = {ntype: one_hot}

def node_one_hot_init(graph):
    """Initialize node feature "f" as a unique one-hot id per node (int8).

    Builds an identity matrix over all nodes in the graph and hands each node
    type its contiguous slice of rows, so every node gets a distinct one-hot row.
    """
    total = graph.num_nodes()
    identity = torch.eye(total, dtype=torch.int8)
    offset = 0
    for ntype in graph.ntypes:
        count = graph.num_nodes(ntype)
        graph.ndata["f"] = {ntype: identity[offset : offset + count]}
        offset += count

def dec_pert(ratio=0.1):
    """Build a graph-perturbation function for the decoder's pretext task.

    The returned function clones a graph and, for every canonical edge type,
    adds ``ratio * num_edges`` random edges tagged with ``id == -1`` so they
    can later be recognized as negative (fake) edges.
    """
    def _dec_pert(graph):
        perturbed = graph.clone()
        for canonical in perturbed.canonical_etypes:
            src_type, _, dst_type = canonical
            n_fake = int(perturbed.num_edges(canonical) * ratio)
            # random endpoints drawn uniformly over each endpoint type's nodes
            fake_src = backend.randint([n_fake], perturbed.idtype, perturbed.device, low=0, high=perturbed.num_nodes(src_type))
            fake_dst = backend.randint([n_fake], perturbed.idtype, perturbed.device, low=0, high=perturbed.num_nodes(dst_type))
            perturbed.add_edges(fake_src, fake_dst, etype=canonical, data={"id": torch.full((n_fake, 1), -1)})
        return perturbed
    return _dec_pert

def compute_pretext_loss(predictions, labels, dev):
    """Binary cross-entropy loss for the link-prediction pretext task.

    Args:
        predictions: dict mapping edge type -> logit tensor for that type's edges.
        labels: dict mapping edge type -> edge "id" tensor; an id of -1 marks a
            negative (randomly added) edge, anything else marks a real edge.
        dev: device the loss tensors should live on.

    Returns:
        Scalar BCE-with-logits loss over all edge types combined.
    """
    logit_chunks = []
    target_chunks = []
    for etype, logit in predictions.items():
        logit_chunks.append(logit)
        # negative edges have id == -1, so real edges become target 1, fakes 0
        target_chunks.append((labels[etype] != -1).flatten())

    if not logit_chunks:
        # no edge types at all: fall back to empty tensors (matches original behavior)
        logits = torch.empty(0, device=dev)
        targets = torch.empty(0, device=dev)
    else:
        # single cat instead of repeated concatenation inside the loop (was O(n^2));
        # cast targets explicitly rather than relying on cat's implicit promotion
        logits = torch.cat(logit_chunks).to(dev)
        targets = torch.cat(target_chunks).to(dev).to(torch.float32)

    return F.binary_cross_entropy_with_logits(logits.squeeze(), targets)

def compute_downstream_loss(predictions, labels):
    """Cross-entropy loss for the downstream multi-class classification task."""
    # nn.CrossEntropyLoss with default (mean) reduction is equivalent to F.cross_entropy
    criterion = torch.nn.CrossEntropyLoss()
    return criterion(predictions, labels)

def plot_disease_mat(y_true, y_pred):
    """Print and plot the 3x3 confusion matrix for the disease task.

    Rows are true classes, columns are predicted classes
    (None / Treats / Palliates).
    """
    matrix = metrics.confusion_matrix(y_true, y_pred)
    print(matrix)

    plt.imshow(matrix, cmap=plt.cm.Blues)
    plt.title('Confusion matrix')
    plt.xticks(np.arange(3), ['None', 'Treats', 'Palliates'])
    plt.yticks(np.arange(3), ['None', 'Treats', 'Palliates'])
    plt.xlabel('Predicted')
    plt.ylabel('True')

    # annotate every cell with its raw count
    for row in range(3):
        for col in range(3):
            plt.text(col, row, str(matrix[row][col]), horizontalalignment='center', verticalalignment='center')

    plt.colorbar()
    plt.show()

def plot_side_effects_mat(y_true, y_pred):
    """Print and plot the 2x2 confusion matrix for the side-effect task.

    Rows are true classes, columns are predicted classes (None / causes SE).
    """
    matrix = metrics.confusion_matrix(y_true, y_pred)
    print(matrix)

    plt.imshow(matrix, cmap=plt.cm.Blues)
    plt.title('Confusion matrix')
    plt.xticks(np.arange(2), ['None', 'causes SE'])
    plt.yticks(np.arange(2), ['None', 'causes SE'])
    plt.xlabel('Predicted')
    plt.ylabel('True')

    # annotate every cell with its raw count
    for row in range(2):
        for col in range(2):
            plt.text(col, row, str(matrix[row][col]), horizontalalignment='center', verticalalignment='center')

    plt.colorbar()
    plt.show()

def plot_mat(y_true, y_pred):
    """Print and plot a 3-class confusion matrix (None / Treats / Palliates).

    Rows are true classes, columns are predicted classes.
    """
    matrix = metrics.confusion_matrix(y_true, y_pred)
    print(matrix)

    plt.imshow(matrix, cmap=plt.cm.Blues)
    plt.title('Confusion matrix')
    plt.xticks(np.arange(3), ['None', 'Treats', 'Palliates'])
    plt.yticks(np.arange(3), ['None', 'Treats', 'Palliates'])
    plt.xlabel('Predicted')
    plt.ylabel('True')

    # annotate every cell with its raw count
    for row in range(3):
        for col in range(3):
            plt.text(col, row, str(matrix[row][col]), horizontalalignment='center', verticalalignment='center')

    plt.colorbar()
    plt.show()

def evaluate(dataset, encoder, decoder, graph, dev, plot=False, plot_fun=plot_disease_mat):
    """Evaluate the encoder/decoder pair on a labeled node-pair dataset.

    Args:
        dataset: iterable of ((src_nodes, dst_nodes), labels) batches.
        encoder: model producing node embeddings from the graph and its "f" features.
        decoder: model scoring (src, dst) node pairs from the embeddings.
        graph: the (full) graph used to compute node embeddings.
        dev: device for inference.
        plot: when True, draw a confusion matrix with ``plot_fun``.
        plot_fun: plotting callback taking (y_true, y_pred).

    Returns:
        (accuracy, macro precision, macro f1) as floats.
    """
    node_emb = encoder(graph, graph.ndata["f"])
    preds = torch.tensor([], dtype=torch.int8).to(dev)
    labels = torch.tensor([], dtype=torch.int8).to(dev)

    for x, y in dataset:
        src_nodes = x[0].to(dev)
        dst_nodes = x[1].to(dev)
        y = y.to(dev)

        pred = decoder(src_nodes, dst_nodes, node_emb)
        pred = torch.argmax(pred, dim=1)
        preds = torch.cat([preds, pred])
        labels = torch.cat([labels, y])

    preds = preds.to("cpu")
    labels = labels.to("cpu")

    if plot:
        plot_fun(labels, preds)

    # sklearn metrics take (y_true, y_pred) in that order; the previous code
    # passed (preds, labels), which silently swaps precision and recall under
    # macro averaging (accuracy is symmetric, so it masked the bug)
    prec = metrics.precision_score(labels, preds, average="macro")
    acc = metrics.accuracy_score(labels, preds)
    f1 = metrics.f1_score(labels, preds, average="macro")

    return acc, prec, f1


def train_pretext(encoder, decoder, enc_dataset, dec_dataset, loss_fn, optimizer, epochs, dev):
    """Train encoder and decoder on the self-supervised edge-prediction pretext task.

    Each epoch pulls one perturbed graph (for encoding) and one expanded graph
    (for decoding) from the respective datasets, scores the expanded graph's
    edges, and optimizes ``loss_fn(pred, edge ids, dev)``.
    """
    encoder.to(dev)
    decoder.to(dev)

    encoder.train()
    decoder.train()

    for epoch in range(epochs):
        # one pre-generated perturbed graph per epoch
        enc_graph = enc_dataset[epoch].to(dev)

        # node embeddings from the perturbed graph's "f" features
        node_emb = encoder(enc_graph, enc_graph.ndata["f"])

        dec_graph = dec_dataset[epoch].to(dev)

        # edge scores on the expanded graph; edata["id"] marks real vs fake edges
        pred = decoder(dec_graph, node_emb)
        loss = loss_fn(pred, dec_graph.edata["id"], dev)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"epoch {epoch} loss = {loss}")


def train_downstream(encoder, decoder, enc_dataset, dec_dataset, val_dataset, optimizer, epochs, dev):
    """Fine-tune encoder and decoder on the downstream pair-classification task.

    Each epoch takes one graph from ``enc_dataset``, iterates the labeled
    ((src, dst), y) batches of ``dec_dataset``, optimizes cross-entropy, and
    reports validation metrics computed via ``evaluate``.

    Args:
        encoder: model producing node embeddings from the graph.
        decoder: model classifying (src, dst) node pairs.
        enc_dataset: indexable by epoch, yielding the training graph.
        dec_dataset: iterable of ((src_nodes, dst_nodes), labels) batches.
        val_dataset: validation batches passed to ``evaluate``.
        optimizer: optimizer over both models' parameters.
        epochs: number of epochs to run.
        dev: device to train on.
    """
    encoder.to(dev)
    decoder.to(dev)

    for epoch in range(epochs):
        encoder.train()
        decoder.train()
        graph = enc_dataset[epoch]
        graph = graph.to(dev)

        tloss = 0.0
        for pairs, y in dec_dataset:
            src_nodes = pairs[0].to(dev)
            dst_nodes = pairs[1].to(dev)
            y = y.to(dev)

            # embeddings are recomputed per batch so gradients flow through the encoder
            node_emb = encoder(graph, graph.ndata["f"])
            pred = decoder(src_nodes, dst_nodes, node_emb)

            loss = compute_downstream_loss(pred, y)
            # .item() detaches: accumulating raw loss tensors ("tloss += loss")
            # kept every batch's autograd graph alive for the whole epoch
            tloss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            encoder.eval()
            decoder.eval()
            acc, prec, f1 = evaluate(val_dataset, encoder, decoder, graph, dev)

        print(f"epoch {epoch} loss: {tloss} val_acc: {acc} val_prec: {prec} val_f1: {f1}")
