Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from .layer import GVP, GVPConvLayer, LayerNorm | |
| from torch_scatter import scatter_mean | |
class AttentionPooling(nn.Module):
    """Self-attention over fused node features, yielding one scalar per node.

    The two inputs are summed element-wise, then projected to queries and
    keys of width ``attention_dim`` and to a single scalar value per node.
    The output is the attention-weighted sum of those scalar values.

    :param input_dim: size of the last dimension of both inputs
    :param attention_dim: width of the query and key projections
    """

    def __init__(self, input_dim, attention_dim):
        super(AttentionPooling, self).__init__()
        self.attention_dim = attention_dim
        self.query_layer = nn.Linear(input_dim, attention_dim, bias=True)
        self.key_layer = nn.Linear(input_dim, attention_dim, bias=True)
        # The value projection emits a single score per node.
        self.value_layer = nn.Linear(input_dim, 1, bias=True)
        # Normalise over the key axis (the last one). For a plain
        # [node_num, node_num] score matrix this equals dim=1, but it also
        # stays correct when the inputs carry leading batch dimensions.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, nodes_features1, nodes_features2):
        """Fuse two feature tensors and attend over the node axis.

        :param nodes_features1: tensor of shape [..., node_num, input_dim]
        :param nodes_features2: tensor broadcastable with ``nodes_features1``
        :return: tensor of shape [..., node_num, 1]
        """
        # Fusion is a plain sum; concatenation would require changing the
        # input width of the projection layers.
        nodes_features = nodes_features1 + nodes_features2

        query = self.query_layer(nodes_features)
        key = self.key_layer(nodes_features)
        value = self.value_layer(nodes_features)

        # [..., node_num, node_num]; the raw dot products are not scaled.
        attention_scores = torch.matmul(query, key.transpose(-2, -1))
        attention_scores = self.softmax(attention_scores)

        # [..., node_num, 1]
        pooled_features = torch.matmul(attention_scores, value)
        return pooled_features
class AutoGraphEncoder(nn.Module):
    """Stack of GVP convolution layers with a per-node classification head.

    Node and edge inputs are normalised and embedded, passed through
    ``num_layers`` ``GVPConvLayer`` blocks, projected by ``W_out`` and fed
    to a small MLP whose output width is ``node_in_dim[0]``.

    :param node_in_dim: node input dimensions; its first entry is also the
                        number of output classes of the dense head
    :param node_h_dim: two-element node hidden dimensions used by the
                       GVP layers
    :param edge_in_dim: edge input dimensions
    :param edge_h_dim: edge hidden dimensions
    :param attention_dim: accepted but not used by this class
    :param num_layers: number of ``GVPConvLayer`` blocks
    :param drop_rate: dropout rate for the conv layers and the dense head
    """

    def __init__(self, node_in_dim, node_h_dim,
                 edge_in_dim, edge_h_dim, attention_dim=64,
                 num_layers=4, drop_rate=0.1) -> None:
        super().__init__()

        # Input embeddings: normalise, then a GVP without activations.
        self.W_v = nn.Sequential(
            LayerNorm(node_in_dim),
            GVP(node_in_dim, node_h_dim, activations=(None, None)),
        )
        self.W_e = nn.Sequential(
            LayerNorm(edge_in_dim),
            GVP(edge_in_dim, edge_h_dim, activations=(None, None)),
        )

        conv_stack = [
            GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(conv_stack)

        hidden_width, _ = node_h_dim
        # The output GVP is asked for (hidden_width, 0) dimensions.
        self.W_out = nn.Sequential(
            LayerNorm(node_h_dim),
            GVP(node_h_dim, (hidden_width, 0)),
        )

        num_labels = node_in_dim[0]
        self.dense = nn.Sequential(
            nn.Linear(hidden_width, 2 * hidden_width),
            nn.ReLU(inplace=True),
            nn.Dropout(p=drop_rate),
            nn.Linear(2 * hidden_width, num_labels),
        )

        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, h_V, edge_index, h_E, node_s_labels):
        """Encode the graph and score every node against its label.

        :param h_V: node features passed to ``W_v``
        :param edge_index: edge index passed to every conv layer
        :param h_E: edge features passed to ``W_e``
        :param node_s_labels: targets for ``nn.CrossEntropyLoss``
        :return: tuple ``(loss, logits)``
        """
        node_repr = self.get_embedding(h_V, edge_index, h_E)
        logits = self.dense(node_repr)
        return self.loss_fn(logits, node_s_labels), logits

    def get_embedding(self, h_V, edge_index, h_E):
        """Return the ``W_out`` projection of the nodes, without the head.

        :param h_V: node features passed to ``W_v``
        :param edge_index: edge index passed to every conv layer
        :param h_E: edge features passed to ``W_e``
        :return: output of ``W_out`` on the final node state
        """
        node_state = self.W_v(h_V)
        edge_state = self.W_e(h_E)
        for conv in self.layers:
            node_state = conv(node_state, edge_index, edge_state)
        return self.W_out(node_state)
class SubgraphClassficationModel(nn.Module):
    '''
    Binary classifier over (parent graph, subgraph) pairs. Both graphs go
    through the same GVP-GNN encoder; their mean-pooled embeddings are
    concatenated and scored by a dense head.

    :param node_in_dim: node dimensions in input graph, should be
                        (6, 3) if using original features
    :param node_h_dim: node dimensions to use in GVP-GNN layers
    :param edge_in_dim: edge dimensions in input graph, should be
                        (32, 1) if using original features
    :param edge_h_dim: edge dimensions to embed to before use
                       in GVP-GNN layers
    :param attention_dim: query/key width of the `AttentionPooling`
                          submodule
    :param num_layers: number of GVP-GNN layers
    :param drop_rate: rate to use in all dropout layers
    '''

    def __init__(self, node_in_dim, node_h_dim,
                 edge_in_dim, edge_h_dim, attention_dim=64,
                 num_layers=4, drop_rate=0.1):
        super().__init__()

        # Input embeddings: normalise, then a GVP without activations.
        self.W_v = nn.Sequential(
            LayerNorm(node_in_dim),
            GVP(node_in_dim, node_h_dim, activations=(None, None)),
        )
        self.W_e = nn.Sequential(
            LayerNorm(edge_in_dim),
            GVP(edge_in_dim, edge_h_dim, activations=(None, None)),
        )

        conv_stack = [
            GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(conv_stack)

        hidden_width, _ = node_h_dim
        self.W_out = nn.Sequential(
            LayerNorm(node_h_dim),
            GVP(node_h_dim, (hidden_width, 0)),
        )

        # Built and registered as a submodule, but forward() does not call it.
        self.attention_classifier = AttentionPooling(hidden_width, attention_dim)

        # The head consumes the concatenated parent and subgraph embeddings,
        # hence the 2 * hidden_width input.
        pair_width = 2 * hidden_width
        self.dense = nn.Sequential(
            nn.Linear(pair_width, pair_width),
            nn.ReLU(inplace=True),
            nn.Dropout(p=drop_rate),
            nn.Linear(pair_width, 1),
        )

        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, h_V_parent, edge_index_parent, h_E_parent, batch_parent,
                h_V_subgraph, edge_index_subgraph, h_E_subgraph, batch_subgraph,
                labels):
        '''
        Encode the parent graphs and the subgraphs, then classify each pair.

        :param h_V_parent: node features of the parent graphs
        :param edge_index_parent: edge index of the parent graphs
        :param h_E_parent: edge features of the parent graphs
        :param batch_parent: index tensor grouping parent nodes for mean pooling
        :param h_V_subgraph: node features of the subgraphs
        :param edge_index_subgraph: edge index of the subgraphs
        :param h_E_subgraph: edge features of the subgraphs
        :param batch_subgraph: index tensor grouping subgraph nodes for
                               mean pooling
        :param labels: targets; cast to float for `BCEWithLogitsLoss`
        :return: tuple (loss, logits); logits keep their trailing
                 dimension of size 1
        '''
        parent_repr = self.get_embedding(
            h_V_parent, edge_index_parent, h_E_parent, batch_parent)
        subgraph_repr = self.get_embedding(
            h_V_subgraph, edge_index_subgraph, h_E_subgraph, batch_subgraph)

        targets = labels.float()

        pair_repr = torch.cat([parent_repr, subgraph_repr], dim=1)
        logits = self.dense(pair_repr)

        # The loss is computed on logits with the trailing dimension removed.
        loss = self.loss_fn(logits.squeeze(-1), targets)
        return loss, logits

    def get_embedding(self, h_V, edge_index, h_E, batch):
        '''
        Encode a batch of graphs and mean-pool the node outputs per group.

        :param h_V: node features passed to `W_v`
        :param edge_index: edge index passed to every conv layer
        :param h_E: edge features passed to `W_e`
        :param batch: index tensor given to `scatter_mean` along dim 0
        :return: mean of the `W_out` node outputs for each index in `batch`
        '''
        node_state = self.W_v(h_V)
        edge_state = self.W_e(h_E)
        for conv in self.layers:
            node_state = conv(node_state, edge_index, edge_state)
        node_out = self.W_out(node_state)
        return scatter_mean(node_out, batch, dim=0)