import logging
import os
import random
from random import randint, choice

import dgl
from dgl.data import DGLDataset
from dgl.transforms import DropEdge
import numpy as np
from numpy.random import shuffle
import torch



# Edge-type abbreviations, keyed by canonical edge type
# (source node type, relation, destination node type). The abbreviations are
# the stems of the edge file names that load_graph() parses.
abbr_dict = {
    ("Anatomy", "downregulates", "Gene"): "AdG",
    ("Anatomy", "expresses", "Gene"): "AeG",
    ("Anatomy", "upregulates", "Gene"): "AuG",
    ("Compound", "binds", "Gene"): "CbG",
    ("Compound", "causes", "Side Effect"): "CcSE",
    ("Compound", "downregulates", "Gene"): "CdG",
    ("Compound", "palliates", "Disease"): "CpD",
    ("Compound", "resembles", "Compound"): "CrC",
    ("Compound", "treats", "Disease"): "CtD",
    ("Compound", "upregulates", "Gene"): "CuG",
    ("Disease", "associates", "Gene"): "DaG",
    ("Disease", "downregulates", "Gene"): "DdG",
    ("Disease", "localizes", "Anatomy"): "DlA",
    ("Disease", "presents", "Symptom"): "DpS",
    ("Disease", "resembles", "Disease"): "DrD",
    ("Disease", "upregulates", "Gene"): "DuG",
    ("Gene", "covaries", "Gene"): "GcG",
    ("Gene", "interacts", "Gene"): "GiG",
    ("Gene", "participates", "Biological Process"): "GpBP",
    ("Gene", "participates", "Cellular Component"): "GpCC",
    ("Gene", "participates", "Molecular Function"): "GpMF",
    ("Gene", "participates", "Pathway"): "GpPW",
    ("Gene", "regulates", "Gene"): "GrG",
    ("Pharmacologic Class", "includes", "Compound"): "PCiC",
}

# Reverse lookup: abbreviation -> canonical edge type. Every abbreviation above
# is unique, so inverting the mapping loses no entries.
abbr_map = dict(zip(abbr_dict.values(), abbr_dict.keys()))

def split_etype(etype_str):
    """Translate an edge-type abbreviation (e.g. ``"CtD"``) into its canonical edge type.

    Returns:
        The ``(source type, relation, destination type)`` tuple registered for
        ``etype_str`` in ``abbr_map``.

    Raises:
        KeyError: if ``etype_str`` is not a known abbreviation.
    """
    canonical_etype = abbr_map[etype_str]
    return canonical_etype

def load_graph(edges_path = "data/edges"):
    """Build a DGL heterograph from a directory of ``*.sparse.npz`` edge files.

    Each file holds one sparse adjacency matrix whose ``indptr``/``indices``
    arrays are read. The file name stem is the edge-type abbreviation
    (e.g. ``CtD.sparse.npz``). A ``>`` in the name (e.g. ``Gr>G.sparse.npz``)
    marks a directed relation. Every other relation is also added reversed,
    using the transposed adjacency.

    The stored layout is taken from the file's ``format`` entry when present.
    Otherwise it is assumed to be CSC. Files not ending in ``.sparse.npz`` are
    skipped.

    Args:
        edges_path: Directory containing the edge files.

    Returns:
        The graph returned by ``dgl.heterograph`` for all loaded relations.

    Raises:
        KeyError: if a file's abbreviation is unknown to ``split_etype``.
        ValueError: if a file records a sparse format other than CSR or CSC.
    """
    # The CSC arrays of A are exactly the CSR arrays of A^T (and vice versa),
    # so a reverse relation can reuse the same arrays with the other layout.
    transposed_fmt = {"csc": "csr", "csr": "csc"}
    edge_index_dict = {}

    # Sorted so relation insertion order does not depend on os.listdir order.
    for edges_file in sorted(os.listdir(edges_path)):
        # Ignore anything that is not an edge matrix (README, .DS_Store, ...).
        if not edges_file.endswith(".sparse.npz"):
            continue

        # "Gr>G.sparse.npz" -> "GrG"; a ">" marks a directed relation.
        etype_str = edges_file.replace(".sparse.npz", "").replace(">", "")
        is_bidirectional = ">" not in edges_file
        etype = split_etype(etype_str)
        logging.getLogger(__name__).debug("loading edge type %s", etype)

        with np.load(os.path.join(edges_path, edges_file)) as data:
            fmt = data["format"].item() if "format" in data.files else "csc"
            if isinstance(fmt, bytes):
                fmt = fmt.decode("ascii")
            if fmt not in transposed_fmt:
                raise ValueError(
                    f"{edges_file}: unsupported sparse format {fmt!r}; "
                    "expected 'csr' or 'csc'"
                )
            # Empty edge-ID array: per the dgl.heterograph docs, DGL then
            # assigns consecutive edge IDs starting from 0.
            edata = (data["indptr"], data["indices"], [])

        edge_index_dict[etype] = (fmt, edata)

        reverse_etype = etype[::-1]
        # For same-type relations (e.g. GiG) the reversed triple IS the forward
        # key; assigning it would silently replace the forward adjacency with
        # its transpose. NOTE(review): such relations are only bidirectional in
        # the graph if the stored matrix is symmetric -- confirm for the data.
        if is_bidirectional and reverse_etype != etype:
            edge_index_dict[reverse_etype] = (transposed_fmt[fmt], edata)

    return dgl.heterograph(edge_index_dict)

def remove_random_edges(graph, ratio_to_remove, edge_types, rng=None):
    """Randomly drop a fraction of the edges of each listed edge type.

    For every edge type, its edge IDs are shuffled and the first
    ``int(num_edges * ratio_to_remove)`` of them are removed.

    Args:
        graph: DGL heterograph to prune.
        ratio_to_remove: Fraction in [0, 1] of each edge type's edges to drop.
        edge_types: Edge types to prune, processed in the given order.
        rng: Optional object with an in-place ``shuffle`` (e.g. a
            ``numpy.random.Generator``) for reproducible sampling. When omitted,
            ``numpy.random.shuffle`` (NumPy's global RNG) is used.

    Returns:
        ``(pruned_graph, removed_eids)``. ``removed_eids`` maps each edge type
        to the numpy array of edge IDs that were removed. Each edge type's IDs
        are read from the graph as pruned so far, just before that type's own
        removal.

    Raises:
        ValueError: if ``ratio_to_remove`` is outside [0, 1].
    """
    # A negative ratio would otherwise slice eids[:-k] and remove almost everything.
    if not 0 <= ratio_to_remove <= 1:
        raise ValueError(f"ratio_to_remove must be in [0, 1], got {ratio_to_remove}")

    shuffle_fn = shuffle if rng is None else rng.shuffle
    removed_eids = {}
    for etype in edge_types:
        eids = graph.edges(form="eid", etype=etype).numpy()
        shuffle_fn(eids)
        eids_to_remove = eids[:int(eids.size * ratio_to_remove)]
        removed_eids[etype] = eids_to_remove
        graph = dgl.remove_edges(graph, eids_to_remove, etype)

    return graph, removed_eids

def remove_edges(graph, id_dict):
    """Delete the given edges from ``graph``, one edge type at a time.

    Args:
        graph: DGL heterograph to prune.
        id_dict: Mapping from edge type to the IDs of that type's edges to remove.

    Returns:
        The graph produced by the last ``dgl.remove_edges`` call, or ``graph``
        unchanged when ``id_dict`` is empty.
    """
    pruned = graph
    for edge_type in id_dict:
        pruned = dgl.remove_edges(pruned, id_dict[edge_type], edge_type)
    return pruned

class HetioNetDataset(DGLDataset):
    """Dataset whose items are randomly perturbed copies of one fixed graph.

    Each item is a ``dgl.batch`` of ``batch_size`` independent applications of
    the perturbation to ``graph``. The index passed to ``__getitem__`` is
    ignored, so ``data_len`` only sets how many items one pass yields.
    """

    def __init__(self, graph, batch_size=1, name="hetiograph", data_len=1000, perturbation_fn=None):
        """
        Args:
            graph: Graph to perturb.
            batch_size: Number of perturbed copies batched into one item.
            name: Dataset name forwarded to ``DGLDataset``.
            data_len: Value reported by ``__len__``.
            perturbation_fn: Callable applied to ``graph`` for every copy.
                Defaults to a new ``DropEdge()`` for this dataset.
        """
        super().__init__(name)
        self.graph = graph
        self.data_len = data_len
        self.batch_size = batch_size
        # Build the default here instead of in the signature: a default argument
        # is evaluated once at definition time and would be shared by every
        # instance of this class.
        self.perturbation = DropEdge() if perturbation_fn is None else perturbation_fn

    def __getitem__(self, idx):
        """Return a batch of ``batch_size`` freshly perturbed graphs (``idx`` is unused)."""
        return dgl.batch([self.perturbation(self.graph) for _ in range(self.batch_size)])

    def __len__(self):
        """Return the configured ``data_len``."""
        return self.data_len


class DownstreamDataset(torch.utils.data.Dataset):
    """Labelled node pairs for multi-relation link prediction.

    For each edge type in ``edge_ids``, the selected split of its edge IDs
    becomes positive pairs labelled ``k``: 1 for the first edge type, 2 for the
    second, and so on, in ``edge_ids`` iteration order. Each positive pair is
    followed by ``n_negative_samples`` corrupted pairs labelled 0.

    A corrupted pair replaces either the source or the destination (chosen
    with ``random.randint(0, 1)``) by a uniformly drawn node of the matching
    type. The draw is repeated until the pair is not an existing edge of that
    edge type.

    Items are ``((src, dst), label)`` with integer node IDs.
    """

    _SPLIT_TYPES = ("train", "val", "test")

    def __init__(self, graph, edge_ids, n_negative_samples=5, split_type="test", test_ratio=0.1):
        """
        Args:
            graph: DGL heterograph in which the IDs in ``edge_ids`` are valid edge
                IDs. NOTE(review): presumably the graph before held-out edges were
                removed -- confirm at the call site.
            edge_ids: Mapping from canonical edge type ``(src_type, rel, dst_type)``
                to a sequence of edge IDs of that type.
            n_negative_samples: Negative pairs drawn per positive pair.
            split_type: ``"train"``, ``"val"`` or ``"test"``. The last
                ``test_ratio`` of each ID sequence is test, the ``test_ratio``
                before it is val, and the remainder is train.
            test_ratio: Fraction of each ID sequence used for test (and for val).

        Raises:
            ValueError: if ``split_type`` is not one of the values above.
        """
        super().__init__()
        # Validate up front so a bad split_type fails even for empty edge_ids.
        if split_type not in self._SPLIT_TYPES:
            raise ValueError("wrong split_type str")

        self.node_pairs = []
        self.labels = []
        for curr_label, (etype, ids) in enumerate(edge_ids.items(), start=1):
            split_ids = self._select_split(ids, split_type, test_ratio)
            self._add_triples(graph, etype, split_ids, curr_label, n_negative_samples)

    @staticmethod
    def _select_split(ids, split_type, test_ratio):
        """Return the slice of ``ids`` that belongs to ``split_type``."""
        n_ids = len(ids)
        train_end = int(n_ids * (1 - 2 * test_ratio))
        val_end = int(n_ids * (1 - test_ratio))
        if split_type == "train":
            return ids[:train_end]
        if split_type == "val":
            return ids[train_end:val_end]
        return ids[val_end:]

    def _add_triples(self, graph, etype, ids, label, n_negative_samples):
        """Append the positive pairs for ``ids`` of ``etype`` and their negatives."""
        src_type, _, dst_type = etype
        num_src = graph.num_nodes(src_type)
        num_dst = graph.num_nodes(dst_type)

        # Store plain ints: torch tensors hash by identity, so a set of tensor
        # pairs never matches an (int, int) lookup and the rejection test below
        # would never reject anything.
        all_src, all_dst = graph.edges(etype=etype)
        existing = set(zip(all_src.tolist(), all_dst.tolist()))

        for eid in ids:
            # Positive pair.
            src, dst = graph.find_edges(eid, etype=etype)
            src, dst = src.item(), dst.item()
            self.node_pairs.append((src, dst))
            self.labels.append(label)

            # Negative pairs. NOTE: a redraw loop cannot terminate if every node
            # of the candidate type is already linked to the fixed endpoint.
            for _ in range(n_negative_samples):
                if random.randint(0, 1) == 0:
                    new_src = random.randint(0, num_src - 1)
                    while (new_src, dst) in existing:
                        new_src = random.randint(0, num_src - 1)
                    self.node_pairs.append((new_src, dst))
                else:
                    new_dst = random.randint(0, num_dst - 1)
                    while (src, new_dst) in existing:
                        new_dst = random.randint(0, num_dst - 1)
                    self.node_pairs.append((src, new_dst))
                self.labels.append(0)

    def __getitem__(self, idx):
        """Return ``((src, dst), label)`` for the ``idx``-th pair."""
        return self.node_pairs[idx], self.labels[idx]

    def __len__(self):
        """Return the number of pairs (positives plus negatives)."""
        return len(self.node_pairs)