RGCN layer implementation in PyTorch
- 将含有 N 种 edge type 的多关系图拆成 N 个同质图，每个同质图只包含一种 edge type 及该 edge type 所连接的节点。
- 对每一种 edge type 初始化一个 GCN 层，每个 GCN 层学习一种关系下的同质图。
- 对于每个节点，聚合各个关系下学习到的节点特征（aggregation = sum + mean；代码中将各关系的结果求和，再加上节点自身特征的线性变换）。
class GraphConvolutionLayer(Module):
    """Relational GCN (R-GCN) layer with one linear transform per edge type.

    For every edge type ``r`` the node annotations are projected with a
    dedicated ``nn.Linear`` and propagated with that edge type's adjacency
    matrix. The per-relation results are summed, and a separate linear
    "self" transform of the node annotations is added. An optional
    activation and then dropout are applied last.

    Args:
        in_features: Size of the last dimension of the node annotations,
            i.e. of ``n_tensor``, or of ``cat(n_tensor, h_tensor)`` when
            ``h_tensor`` is passed to ``forward``.
        out_features: Size of the output node features.
        activation: Callable applied to the aggregated features, or ``None``
            to skip the activation.
        edge_type_num: Number of edge types (relations), called E below.
        dropout_rate: Dropout probability applied to the output.
    """

    def __init__(self, in_features, out_features, activation, edge_type_num, dropout_rate=0.):
        super().__init__()
        self.edge_type_num = edge_type_num
        self.out_features = out_features
        # One projection per edge type. Attribute names are kept unchanged so
        # previously saved state_dicts still load.
        self.edgeType_linears = nn.ModuleList(
            nn.Linear(in_features, out_features) for _ in range(edge_type_num)
        )
        # Transform of each node's own annotation, added to the relation sum.
        self.linear_2 = nn.Linear(in_features, out_features)
        self.activation = activation
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, n_tensor, adj_tensor, h_tensor=None):
        """Run one R-GCN propagation step.

        Args:
            n_tensor: Node features. The node axis is expected second to last
                (``(..., N, F)``); both ``(N, F)`` and ``(B, N, F)`` work.
            adj_tensor: Per-relation adjacency matrices. They must broadcast
                with ``(..., E, N, N)`` under ``torch.matmul``, e.g.
                ``(B, E, N, N)`` or ``(E, N, N)``.
            h_tensor: Optional extra features, concatenated to ``n_tensor``
                along the last dimension before projection.

        Returns:
            Tensor of shape ``(..., N, out_features)``.
        """
        if h_tensor is not None:
            node_annotations = torch.cat((n_tensor, h_tensor), -1)
        else:
            node_annotations = n_tensor

        # Put the relation axis directly in front of the node axis using a
        # negative dim (-3), not a fixed dim 1. This gives (..., E, N, out)
        # for both batched and unbatched inputs. A hard-coded dim 1 breaks
        # (or silently mis-broadcasts) when there is no batch axis.
        per_relation = torch.stack(
            [linear(node_annotations) for linear in self.edgeType_linears],
            dim=-3,
        )
        # (..., E, N, N) @ (..., E, N, out) -> (..., E, N, out); sum over E.
        messages = torch.matmul(adj_tensor, per_relation).sum(dim=-3)

        output = messages + self.linear_2(node_annotations)
        if self.activation is not None:
            output = self.activation(output)
        return self.dropout(output)