path: root/models/graphcnn.py
Diffstat (limited to 'models/graphcnn.py')
-rw-r--r--  models/graphcnn.py  229
1 file changed, 229 insertions, 0 deletions
diff --git a/models/graphcnn.py b/models/graphcnn.py
new file mode 100644
index 0000000..edf2c51
--- /dev/null
+++ b/models/graphcnn.py
@@ -0,0 +1,229 @@
+#coding:utf-8
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import sys
+sys.path.append("models/")
+from mlp import MLP
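+#MLP lives in models/mlp.py; judging from the calls below, it is constructed as MLP(num_mlp_layers, in_dim, hidden_dim, out_dim)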
+
+class GraphCNN(nn.Module):
+ def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim, output_dim, final_dropout, learn_eps, graph_pooling_type, neighbor_pooling_type, device):
+ '''
+ num_layers: number of layers in the neural networks (INCLUDING the input layer)
+ num_mlp_layers: number of layers in mlps (EXCLUDING the input layer)
+ input_dim: dimensionality of input features
+ hidden_dim: dimensionality of hidden units at ALL layers
+ output_dim: number of classes for prediction
+ final_dropout: dropout ratio on the final linear layer
+ learn_eps: If True, learn epsilon to distinguish center nodes from neighboring nodes. If False, aggregate neighbors and center nodes altogether.
+ neighbor_pooling_type: how to aggregate neighbors (sum, average, or max)
+ graph_pooling_type: how to aggregate all nodes in a graph (sum or average)
+ device: which device to use
+ '''
+
+ super(GraphCNN, self).__init__()
+
+ self.final_dropout = final_dropout
+ self.device = device
+ self.num_layers = num_layers
+ self.graph_pooling_type = graph_pooling_type
+ self.neighbor_pooling_type = neighbor_pooling_type
+ self.learn_eps = learn_eps
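+ #one learnable epsilon per layer, used below as (1 + eps[layer])*h when learn_eps is True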
+ self.eps = nn.Parameter(torch.zeros(self.num_layers-1))
+
+ ###List of MLPs
+ self.mlps = torch.nn.ModuleList()
+
+ ###List of batchnorms applied to the output of MLP (input of the final prediction linear layer)
+ self.batch_norms = torch.nn.ModuleList()
+
+ for layer in range(self.num_layers-1):
+ if layer == 0:
+ self.mlps.append(MLP(num_mlp_layers, input_dim, hidden_dim, hidden_dim))
+ else:
+ self.mlps.append(MLP(num_mlp_layers, hidden_dim, hidden_dim, hidden_dim))
+
+ self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
+
+ #Linear functions that map the hidden representation at each layer (including the input features) to a prediction score; forward() sums these per-layer scores
+ self.linears_prediction = torch.nn.ModuleList()
+ for layer in range(num_layers):
+ if layer == 0:
+ self.linears_prediction.append(nn.Linear(input_dim, output_dim))
+ else:
+ self.linears_prediction.append(nn.Linear(hidden_dim, output_dim))
+
+
+ def __preprocess_neighbors_maxpool(self, batch_graph):
+ ###create padded_neighbor_list in concatenated graph
+
+ #compute the maximum number of neighbors within the graphs in the current minibatch
+ max_deg = max([graph.max_neighbor for graph in batch_graph])
+
+ padded_neighbor_list = []
+ start_idx = [0]
+
+
+ for i, graph in enumerate(batch_graph):
+ start_idx.append(start_idx[i] + len(graph.g))
+ padded_neighbors = []
+ for j in range(len(graph.neighbors)):
+ #add off-set values to the neighbor indices
+ pad = [n + start_idx[i] for n in graph.neighbors[j]]#neighbor indices of node j, shifted into the concatenated batch
+ #pad with -1 up to max_deg entries; the -1 dummies index the min row appended in maxpool, so they never affect the max
+ pad.extend([-1]*(max_deg - len(pad)))
+
+ #Add center nodes in the maxpooling if learn_eps is False, i.e., aggregate center nodes and neighbor nodes altogether.
+ if not self.learn_eps:
+ pad.append(j + start_idx[i])
+
+ padded_neighbors.append(pad)
+ padded_neighbor_list.extend(padded_neighbors)
+
+ return torch.LongTensor(padded_neighbor_list)
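+ #illustrative example (hypothetical batch): one 3-node star graph where node 0 neighbors [1, 2]
+ #and nodes 1 and 2 each neighbor only node 0 gives max_deg = 2 and, with learn_eps = True,
+ #padded_neighbor_list = [[1, 2], [0, -1], [0, -1]]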
+
+
+ def __preprocess_neighbors_sumavepool(self, batch_graph):
+ ###create block diagonal sparse matrix
+
+ edge_mat_list = []
+ start_idx = [0]
+ for i, graph in enumerate(batch_graph):
+ start_idx.append(start_idx[i] + len(graph.g))
+ edge_mat_list.append(graph.edge_mat + start_idx[i])
+ Adj_block_idx = torch.cat(edge_mat_list, 1)
+ Adj_block_elem = torch.ones(Adj_block_idx.shape[1])
+
+ #Add self-loops in the adjacency matrix if learn_eps is False, i.e., aggregate center nodes and neighbor nodes altogether.
+
+ if not self.learn_eps:
+ num_node = start_idx[-1]
+ self_loop_edge = torch.LongTensor([range(num_node), range(num_node)])
+ elem = torch.ones(num_node)
+ Adj_block_idx = torch.cat([Adj_block_idx, self_loop_edge], 1)#(2, num edges + num nodes): edge index pairs followed by the self-loop pairs (0,0), (1,1), ...
+ Adj_block_elem = torch.cat([Adj_block_elem, elem], 0)#ones for the edges, then ones for the self-loops
+ #sparse block-diagonal adjacency: 1 at every shifted edge position, plus 1s on the diagonal when learn_eps is False
+ Adj_block = torch.sparse.FloatTensor(Adj_block_idx, Adj_block_elem, torch.Size([start_idx[-1],start_idx[-1]]))
+
+ return Adj_block.to(self.device)
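+ #illustrative example (hypothetical batch): two 2-node graphs, each storing the undirected edge
+ #(0, 1) in both directions, give start_idx = [0, 2, 4] and shifted edges (0,1), (1,0), (2,3), (3,2);
+ #with learn_eps = False the 4 x 4 Adj_block additionally has 1s on its diagonal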
+
+
+ def __preprocess_graphpool(self, batch_graph):
+ ###create sum or average pooling sparse matrix over entire nodes in each graph (num graphs x num nodes)
+
+ start_idx = [0]
+
+ #compute the padded neighbor list
+ for i, graph in enumerate(batch_graph):
+ start_idx.append(start_idx[i] + len(graph.g))
+
+ idx = []
+ elem = []
+ for i, graph in enumerate(batch_graph):
+ ###average pooling
+ if self.graph_pooling_type == "average":
+ elem.extend([1./len(graph.g)]*len(graph.g))
+
+ else:
+ ###sum pooling
+ elem.extend([1]*len(graph.g))
+
+ idx.extend([[i, j] for j in range(start_idx[i], start_idx[i+1], 1)])
+ elem = torch.FloatTensor(elem)
+ idx = torch.LongTensor(idx).transpose(0,1)#(2, total num nodes): row 0 is the graph each node belongs to, row 1 is the node's position in the concatenated batch (e.g. 300 graphs of 30 nodes give graph indices 0..299 over node columns 0..8999)
+ graph_pool = torch.sparse.FloatTensor(idx, elem, torch.Size([len(batch_graph), start_idx[-1]]))#(num graphs x total num nodes) matrix; entry (i, j) is nonzero iff node j belongs to graph i (e.g. a (300, 9000) pooling matrix)
+ return graph_pool.to(self.device)
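+ #illustrative example (hypothetical batch): graphs with 2 and 3 nodes under sum pooling give a
+ #(2 x 5) graph_pool with 1s at (0,0), (0,1), (1,2), (1,3), (1,4), so torch.spmm(graph_pool, h)
+ #sums node features per graph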
+
+ def maxpool(self, h, padded_neighbor_list):
+ ###append a dummy row holding the column-wise minima of h: index -1 selects it, and it can never win the max
+
+ dummy = torch.min(h, dim = 0)[0]
+ h_with_dummy = torch.cat([h, dummy.reshape((1, -1)).to(self.device)])
+ pooled_rep = torch.max(h_with_dummy[padded_neighbor_list], dim = 1)[0]
+ return pooled_rep
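+ #e.g. a padded row [5, 2, -1] gathers rows 5 and 2 of h plus the dummy min row (index -1),
+ #and the max over dim 1 yields that node's max-pooled neighborhood features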
+
+
+ def next_layer_eps(self, h, layer, padded_neighbor_list = None, Adj_block = None):
+ ###pool neighboring nodes and the center node separately, combining them with epsilon reweighting:
+ ###h <- ReLU(BatchNorm(MLP((1 + eps[layer])*h + pooled_neighbors)))
+
+ if self.neighbor_pooling_type == "max":
+ ##If max pooling
+ pooled = self.maxpool(h, padded_neighbor_list)
+ else:
+ #If sum or average pooling
+ pooled = torch.spmm(Adj_block, h)
+ if self.neighbor_pooling_type == "average":
+ #If average pooling
+ degree = torch.spmm(Adj_block, torch.ones((Adj_block.shape[0], 1)).to(self.device))
+ pooled = pooled/degree
+
+ #Reweights the center node representation when aggregating it with its neighbors
+ pooled = pooled + (1 + self.eps[layer])*h
+ pooled_rep = self.mlps[layer](pooled)
+ h = self.batch_norms[layer](pooled_rep)
+
+ #non-linearity
+ h = F.relu(h)
+ return h
+
+
+ def next_layer(self, h, layer, padded_neighbor_list = None, Adj_block = None):
+ ###pooling neighboring nodes and center nodes altogether
+
+ if self.neighbor_pooling_type == "max":
+ ##If max pooling
+ pooled = self.maxpool(h, padded_neighbor_list)
+ else:
+ #If sum or average pooling
+ pooled = torch.spmm(Adj_block, h)
+ if self.neighbor_pooling_type == "average":
+ #If average pooling
+ degree = torch.spmm(Adj_block, torch.ones((Adj_block.shape[0], 1)).to(self.device))
+ pooled = pooled/degree
+
+ #representation of neighboring and center nodes
+ pooled_rep = self.mlps[layer](pooled)
+
+ h = self.batch_norms[layer](pooled_rep)
+
+ #non-linearity
+ h = F.relu(h)
+ return h
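+ #next_layer is the learn_eps = False counterpart of next_layer_eps: the center node enters the
+ #aggregation through the self-loops (sum/average pooling) or the appended center index (max pooling)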
+
+
+ def forward(self, batch_graph):
+ X_concat = torch.cat([graph.node_features for graph in batch_graph], 0).to(self.device)
+ graph_pool = self.__preprocess_graphpool(batch_graph)
+
+ if self.neighbor_pooling_type == "max":
+ padded_neighbor_list = self.__preprocess_neighbors_maxpool(batch_graph)
+ else:
+ Adj_block = self.__preprocess_neighbors_sumavepool(batch_graph)
+
+ #list of hidden representation at each layer (including input)
+ hidden_rep = [X_concat]
+ h = X_concat
+
+ for layer in range(self.num_layers-1):
+ if self.neighbor_pooling_type == "max" and self.learn_eps:
+ h = self.next_layer_eps(h, layer, padded_neighbor_list = padded_neighbor_list)
+ elif not self.neighbor_pooling_type == "max" and self.learn_eps:
+ h = self.next_layer_eps(h, layer, Adj_block = Adj_block)
+ elif self.neighbor_pooling_type == "max" and not self.learn_eps:
+ h = self.next_layer(h, layer, padded_neighbor_list = padded_neighbor_list)
+ elif not self.neighbor_pooling_type == "max" and not self.learn_eps:
+ h = self.next_layer(h, layer, Adj_block = Adj_block)
+
+ hidden_rep.append(h)
+
+ score_over_layer = 0
+
+ #perform pooling over all nodes in each graph in every layer
+ for layer, h in enumerate(hidden_rep):
+ pooled_h = torch.spmm(graph_pool, h)
+ score_over_layer += F.dropout(self.linears_prediction[layer](pooled_h), self.final_dropout, training = self.training)
+
+ return score_over_layer
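+
+#Minimal usage sketch (illustrative values only; assumes batch_graph is a list of graph objects
+#with the fields used above: g, neighbors, edge_mat, node_features, max_neighbor):
+#
+#    device = torch.device("cpu")
+#    model = GraphCNN(num_layers = 5, num_mlp_layers = 2, input_dim = 7, hidden_dim = 64,
+#                     output_dim = 2, final_dropout = 0.5, learn_eps = True,
+#                     graph_pooling_type = "sum", neighbor_pooling_type = "sum",
+#                     device = device).to(device)
+#    scores = model(batch_graph)    #(num graphs, output_dim) prediction scores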