# Load data

Uncomment the code below to download the data.

In [1]:
#!wget https://github.com/hetio/hetionet/raw/main/hetnet/matrix/hetionet-v1.0.hetmat.zip
#!7za x hetionet-v1.0.hetmat.zip -odata

In [1]:
import torch
from RelationalLearning.DataLoader import load_graph

# Build the graph through the project loader; the cell output lists its
# (source node type, relation, target node type) triples.
graph = load_graph()

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    dev = "cuda"
else:
    dev = "cpu"


('Disease', 'presents', 'Symptom')
('Anatomy', 'downregulates', 'Gene')
('Compound', 'binds', 'Gene')
('Disease', 'associates', 'Gene')
('Gene', 'participates', 'Cellular Component')
('Gene', 'participates', 'Pathway')
('Gene', 'regulates', 'Gene')
('Disease', 'localizes', 'Anatomy')
('Anatomy', 'upregulates', 'Gene')
('Compound', 'palliates', 'Disease')
('Compound', 'causes', 'Side Effect')
('Gene', 'participates', 'Biological Process')
('Compound', 'upregulates', 'Gene')
('Compound', 'resembles', 'Compound')
('Gene', 'covaries', 'Gene')
('Gene', 'interacts', 'Gene')
('Disease', 'upregulates', 'Gene')
('Pharmacologic Class', 'includes', 'Compound')
('Gene', 'participates', 'Molecular Function')
('Compound', 'downregulates', 'Gene')
('Disease', 'downregulates', 'Gene')
('Disease', 'resembles', 'Disease')
('Anatomy', 'expresses', 'Gene')
('Compound', 'treats', 'Disease')


## Initialize node features

In [2]:
from RelationalLearning.utils import ntype_one_hot_init

# Attach initial node features to `graph`. Judging by the helper's name these are
# one-hot encodings of the node type — TODO confirm in RelationalLearning.utils.
ntype_one_hot_init(graph)
# Input feature width later passed to the encoder as `in_size`.
# NOTE(review): 11 equals the number of distinct node types appearing in the
# edge-type listing printed when the graph was loaded; presumably this must stay
# equal to len(graph.ntypes) — confirm before changing the graph schema.
in_feature_size = 11

## Split dataset

Dictionary of edge IDs reserved for the downstream dataset, keyed by canonical edge type. For each listed edge type it contains a randomly chosen 50% of that type's edges from the original graph.

In [3]:
eids_to_remove ={('Compound',
                  'treats',
                  'Disease'): [ 21, 336, 291,  77, 237, 749,  32, 161,  75, 147, 514, 139, 460,
                                538, 273, 511, 244, 724, 662, 351, 381, 153, 279, 315, 501, 333,
                                491, 659, 526, 173,  36, 660, 697, 670,  58, 293,  64, 145, 259,
                                331, 330, 246, 704,  20, 150, 532, 166, 728, 609, 226, 635, 398,
                                644, 530, 207, 481, 505, 695, 421, 223, 629, 506, 642, 126, 451,
                                535, 754, 174, 420, 624, 157, 186, 680, 574,   7, 596, 521, 397,
                                172, 540, 567, 180, 747, 200, 753, 190, 545, 272, 309,  94, 630,
                                151, 664, 342, 392, 477, 209, 368,  70, 338,  85, 106, 546, 713,
                                683, 539, 443,  38,  91, 577, 107, 566, 300, 557, 193, 158, 235,
                                479, 619, 135, 445, 727, 417, 447, 484,  27,  83, 188, 414, 575,
                                60, 677, 721, 212, 418,  49, 666, 322, 327, 458, 499, 655, 427,
                                498,   6, 376,  11,  45, 613, 650, 592, 254, 738,  71,  81, 723,
                                387, 364, 115, 692, 372,  90, 409,  72, 425, 168, 298, 320, 127,
                                571, 442, 137, 136, 502, 366, 185, 752, 469, 522, 198,  86, 544,
                                8, 385, 165, 335, 108, 733, 120, 271, 523, 676, 579, 488, 433,
                                356, 718, 520,  14, 205, 296, 652, 547, 494, 357, 288, 346, 100,
                                113, 308,   1, 441, 275, 614,   4, 578, 450,  25, 382, 359, 700,
                                553, 403, 474,  41,  46, 656, 408, 653, 103, 274, 369, 146, 175,
                                131, 492, 508, 179, 416, 454, 497, 389, 292, 674, 101, 473, 512,
                                541, 632,  95, 472, 163, 453, 679, 117,  65, 702, 437,  56, 696,
                                595, 688,  39, 324,  13, 590, 238, 562, 634, 255, 572, 299,  62,
                                467, 313,  82,  23, 503,   2, 170, 543, 365, 321, 218, 582, 661,
                                74,   3, 542, 178,  80, 345, 672, 305, 225, 281, 339, 739, 513,
                                714, 370, 536, 628, 563, 534,  96,  79, 438, 604, 736, 183, 360,
                                586, 162, 239, 751, 465, 565, 439, 363, 456, 343, 449, 568, 353,
                                16, 182, 712, 611, 269, 375, 717, 734,  63, 455, 196, 558, 627,
                                52, 649, 277, 435, 527, 603, 725, 691, 134, 140, 531, 731, 401,
                                110, 264, 720, 386, 412, 690, 570, 518, 519, 220, 233, 475,  89,
                                675, 606, 119, 132, 287,  68,  55, 129, 230, 478, 487, 740, 496],
                  ('Compound',
                  'palliates',
                  'Disease'): [317,  22, 139, 198, 181, 368, 320,  48, 230, 155,  90, 372, 177,
                               339,  59, 302, 330, 336, 207, 120, 165,  23, 118,  47, 182, 334,
                               340, 356, 125, 328,  44,  19, 174,  49, 201, 200, 150, 146, 349,
                               185,   2, 240, 314, 107,  33, 331, 238, 128, 290, 347, 257, 343,
                               153, 383,  16, 218, 247, 157,  61,  78, 374, 101, 378, 233,  41,
                               268,   6,  99, 127, 245, 136, 178, 361, 332, 363, 326, 131,  60,
                               95, 184, 256, 252, 183, 376, 253,   7, 197,  55,  30, 282, 362,
                               96, 386, 350, 277, 338, 137, 160, 273, 217, 318, 180,  79,  65,
                               158, 116,  52,  97, 271, 210,   3, 309, 172, 237, 264,  34, 266,
                               4, 227,  94,  75, 103, 192, 170, 145, 333,  12, 259,  36, 206,
                               154, 134,  37,  66, 203, 189, 104, 360, 303, 254, 344,   9, 289,
                               285,  18, 366, 191, 382, 225, 224,  13, 325, 179, 286, 119, 167,
                               24, 151,   5, 106,  71,   0,  82, 108, 312, 284, 243, 152,  10,
                               196, 173,  29,  87, 371, 364, 171, 246, 294, 208, 345, 365,  77,
                               215,  67, 216, 295, 109, 121,  42, 132, 163, 112,  31, 244,  17]} 

In [4]:
from RelationalLearning.DataLoader import remove_edges

# Graph fed to the encoders: `graph` with the edges listed in `eids_to_remove`
# taken out, so those held-out edges can serve as downstream prediction targets.
# NOTE(review): presumably a new graph is returned and `graph` itself is left
# intact (it is reused unchanged for the downstream datasets) — confirm.
input_graph = remove_edges(graph, eids_to_remove)

In [5]:
from RelationalLearning.DataLoader import HetioNetDataset, DownstreamDataset
from RelationalLearning.utils import dec_pert
import torch

import random
import dgl
import numpy as np

# Seed every random number generator in play (Python, DGL, NumPy, PyTorch CPU and
# CUDA) with the same value so that sampling and weight initialisation repeat.
SEED = 42
random.seed(SEED)
dgl.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Pretext task: the encoder dataset uses the dataset's default perturbation, the
# decoder dataset uses dec_pert(.2).
pretext_enc_dataset = HetioNetDataset(input_graph)
pretext_dec_dataset = HetioNetDataset(input_graph, perturbation_fn=dec_pert(.2))

# Downstream task: the encoder gets the reduced graph with an identity
# perturbation; the decoder datasets are built from the full graph plus the
# held-out edge ids, one dataset per split.
downstream_enc_dataset = HetioNetDataset(input_graph, perturbation_fn=lambda g: g)

_split_kwargs = {"n_negative_samples": 1, "test_ratio": 0.1}
train_downstream_dataset = DownstreamDataset(graph, eids_to_remove, split_type="train", **_split_kwargs)
val_downstream_dataset = DownstreamDataset(graph, eids_to_remove, split_type="val", **_split_kwargs)
test_downstream_dataset = DownstreamDataset(graph, eids_to_remove, split_type="test", **_split_kwargs)

# Training is mini-batched and shuffled; validation and test are each served as
# one full-size, unshuffled batch.
train_downstream_loader = torch.utils.data.DataLoader(
    train_downstream_dataset,
    batch_size=32,
    shuffle=True,
)
val_downstream_loader = torch.utils.data.DataLoader(
    val_downstream_dataset,
    batch_size=len(val_downstream_dataset),
    shuffle=False,
)
test_downstream_loader = torch.utils.data.DataLoader(
    test_downstream_dataset,
    batch_size=len(test_downstream_dataset),
    shuffle=False,
)

# Models' definitions

In [8]:

from torch import nn
from dgl.nn.pytorch import HeteroGraphConv, SAGEConv
from torch.nn import Linear, Sequential, ReLU, Softmax
import torch.nn.functional as F

class Encoder(nn.Module):
    """Two-layer heterogeneous GraphSAGE encoder.

    For each of the two layers one ``SAGEConv`` is created per distinct
    edge-type name of ``graph``. Each layer is wrapped in a ``HeteroGraphConv``
    that combines the per-relation results with ``aggregate="mean"``.

    Args:
        graph: Heterograph whose ``etypes`` define which relations get a
            convolution.
        in_size: Width of the input node features.
        n_feat: Width of the hidden and output node representations.
    """

    def __init__(self, graph, in_size, n_feat) -> None:
        super(Encoder, self).__init__()

        # De-duplicate the edge-type names, then sort them. Iterating a bare
        # set of strings follows hash order, which changes between interpreter
        # runs (string hash randomisation); that would change the order in which
        # the layers draw their initial weights and defeat the fixed seeds.
        etypes = sorted(set(graph.etypes))

        # First layer: "gcn" aggregator, input width -> hidden width.
        convolutions_config_1 = {
            etype: SAGEConv(in_size, n_feat, "gcn", feat_drop=0.5)
            for etype in etypes
        }

        # Second layer: "pool" aggregator, hidden width -> hidden width.
        convolutions_config_2 = {
            etype: SAGEConv(n_feat, n_feat, "pool", feat_drop=0.5)
            for etype in etypes
        }

        # convolution layers for node representation updates
        self.conv1 = HeteroGraphConv(convolutions_config_1, aggregate="mean")
        self.conv2 = HeteroGraphConv(convolutions_config_2, aggregate="mean")

    def forward(self, g, in_feat):
        """Compute node representations for ``g``.

        Args:
            g: Graph to run the convolutions on.
            in_feat: Input node features, passed straight to the first
                ``HeteroGraphConv`` (presumably a dict keyed by node type —
                confirm against the caller).

        Returns:
            The output of the second ``HeteroGraphConv``: a mapping from node
            type to that type's representations.
        """
        # apply first conv layer on node representations
        h_n = self.conv1(g, in_feat)

        # apply relu and second conv layer on node representations
        h_n = self.conv2(g, {ntype: F.relu(h_i) for ntype, h_i in h_n.items()})

        return h_n


class PretextDecoder(torch.nn.Module):
    """Per-relation edge scorer used for the pretext task.

    Holds one MLP per distinct edge-type name of ``graph``
    (``Linear -> BatchNorm1d -> Dropout -> Linear -> ReLU -> Linear``) that maps
    the concatenated source and destination node features of an edge to a
    single score.

    Args:
        graph: Heterograph whose ``etypes`` define which relations get an MLP.
        in_size: Width of the concatenated (source, destination) features,
            i.e. the input width of every MLP.
        n_feat: Width of the hidden layers.
    """

    def __init__(self, graph, in_size, n_feat) -> None:
        super(PretextDecoder, self).__init__()

        # feed-forward network for each edge type
        self.etype_mlp = torch.nn.ModuleDict()
        # Sorted so the MLPs are always created in the same order. Plain set
        # iteration over strings follows hash order, which differs between
        # interpreter runs and would make the seeded weight initialisation
        # irreproducible.
        for etype in sorted(set(graph.etypes)):
            self.etype_mlp[etype] = Sequential(
                Linear(in_size, n_feat),
                nn.BatchNorm1d(n_feat),
                nn.Dropout(),
                Linear(n_feat, n_feat),
                ReLU(),
                Linear(n_feat, 1)
            )

    def forward(self, g, in_feat):
        """Score every edge of ``g``.

        Args:
            g: Graph whose edges are scored. It is only modified inside a
                ``local_scope``.
            in_feat: Node features assigned to ``g.ndata["f"]``.

        Returns:
            ``g.edata["scores"]`` as read inside the local scope (presumably a
            mapping from canonical edge type to a score tensor when ``g`` has
            several edge types — confirm against DGL).
        """
        with g.local_scope():
            # Assign the input features to the nodes
            g.ndata["f"] = in_feat

            # Compute the score for each edge type using the corresponding MLP
            for rel in g.canonical_etypes:
                # MLPs are keyed by relation name only, the middle element of the triple.
                mlp = self.etype_mlp[rel[1]]
                g.apply_edges(lambda edges: {"scores": mlp(torch.cat((edges.src["f"], edges.dst["f"]), 1))}, etype=rel)

            return g.edata["scores"]


class DownstreamDecoder(torch.nn.Module):
    """MLP head that scores (Compound, Disease) node pairs.

    The representations of both endpoints are concatenated and passed through
    ``Linear -> Dropout -> ReLU -> Linear``, giving 3 raw outputs per pair
    (presumably one logit per downstream class — confirm against the labels
    produced by the downstream dataset).

    Args:
        in_size: Width of the concatenated pair representation.
        n_feat: Width of the hidden layer.
    """

    def __init__(self, in_size, n_feat) -> None:
        super(DownstreamDecoder, self).__init__()

        self.mlp = Sequential(
            Linear(in_size, n_feat),
            nn.Dropout(),
            nn.ReLU(),
            Linear(n_feat, 3)
        )

    def forward(self, src_nodes, dst_nodes, in_feat):
        """Score a batch of (Compound, Disease) pairs.

        Args:
            src_nodes: Indices of the Compound node of each pair.
            dst_nodes: Indices of the Disease node of each pair.
            in_feat: Mapping with at least the keys ``"Compound"`` and
                ``"Disease"``, each holding that node type's representations
                indexed by node id along the first dimension.

        Returns:
            Tensor with one row of 3 scores per pair. An empty batch yields an
            empty ``(0, 3)`` result instead of raising.
        """
        compound_feat = in_feat["Compound"]
        disease_feat = in_feat["Disease"]

        # Gather all rows in one indexing operation rather than one `.item()`
        # call per pair: no per-element host/device sync, and empty batches work.
        # reshape(-1) keeps inputs of shape (batch, 1) behaving like (batch,).
        src_idx = torch.as_tensor(src_nodes, dtype=torch.long, device=compound_feat.device).reshape(-1)
        dst_idx = torch.as_tensor(dst_nodes, dtype=torch.long, device=disease_feat.device).reshape(-1)

        # Row i is [Compound representation of pair i | Disease representation of pair i].
        pairs = torch.cat([compound_feat[src_idx], disease_feat[dst_idx]], dim=1)

        return self.mlp(pairs)

# Training

In [9]:
# Width of the node embeddings the encoder produces.
embeddings_size = 32

# The encoder is shaped after the reduced graph (held-out edges removed) and
# takes the initial node features as input.
encoder = Encoder(graph=input_graph, in_size=in_feature_size, n_feat=embeddings_size)

## Pretext

In [None]:
import itertools
from RelationalLearning.utils import train_pretext, compute_pretext_loss

# Edge-scoring head for the pretext task. Its input is a pair of node
# embeddings, hence twice the embedding width.
pretext_decoder = PretextDecoder(graph=input_graph, in_size=embeddings_size * 2, n_feat=64)

# A single Adam optimizer updates the encoder and the pretext decoder together.
optimizer = torch.optim.Adam(
    [*encoder.parameters(), *pretext_decoder.parameters()],
    lr=0.01,
)

train_pretext(
    encoder=encoder,
    decoder=pretext_decoder,
    enc_dataset=pretext_enc_dataset,
    dec_dataset=pretext_dec_dataset,
    optimizer=optimizer,
    loss_fn=compute_pretext_loss,
    epochs=150,
    dev=dev,
)

## Downstream

In [None]:
from RelationalLearning.utils import train_downstream

# Pair-classification head for the downstream task. Its input is a pair of node
# embeddings, hence twice the embedding width.
downstream_decoder = DownstreamDecoder(in_size=embeddings_size * 2, n_feat=32)

# A fresh Adam optimizer (replacing the pretext one) updates the encoder together
# with the downstream decoder.
optimizer = torch.optim.Adam(
    [*encoder.parameters(), *downstream_decoder.parameters()],
    lr=0.002,
)

train_downstream(
    encoder=encoder,
    decoder=downstream_decoder,
    enc_dataset=downstream_enc_dataset,
    dec_dataset=train_downstream_loader,
    val_dataset=val_downstream_loader,
    optimizer=optimizer,
    epochs=17,
    dev=dev,
)

# Evaluation

In [None]:
from RelationalLearning.utils import evaluate

# Put both networks in evaluation mode, move them and the reduced graph to the
# target device, then evaluate on the test loader.
evaluate(
    test_downstream_loader,
    encoder.eval().to(dev),
    downstream_decoder.eval().to(dev),
    input_graph.to(dev),
    dev,
    plot=True,
)