EvolveGCN (EGCN) combines a GCN with an RNN in order to capture the temporal dynamics of a dynamic graph.
Role of the GCN: extract the node features of the graph snapshot at each time step; it does not modify its own parameters.
Role of the RNN: update the GCN's parameters by combining node features from adjacent time steps.
The -H variant focuses more on the information carried by the nodes, while the -O variant focuses more on changes in the overall graph structure.
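For reference, a hedged summary of the update rules from the EvolveGCN paper, in the paper's notation ($\hat{A}_t$ is the normalized adjacency matrix at time $t$, $H_t^{(l)}$ the layer-$l$ node embeddings, $W_t^{(l)}$ the layer-$l$ weights):

$$W_t^{(l)} = \mathrm{GRU}\big(H_t^{(l)},\, W_{t-1}^{(l)}\big) \quad \text{(-H: node embeddings drive the weight update)}$$

$$W_t^{(l)} = \mathrm{LSTM}\big(W_{t-1}^{(l)}\big) \quad \text{(-O: the weights evolve on their own)}$$

$$H_t^{(l+1)} = \mathrm{GCONV}\big(\hat{A}_t,\, H_t^{(l)},\, W_t^{(l)}\big)$$

The implementation below uses a matrix GRU for the recurrence and calls it without a node-embedding input, so the previous weights serve as their own input; this is a GRU-based analogue of the -O update.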
Code implementation
import torch
import numpy as np
np.set_printoptions(suppress=True)
import torch.nn as nn
from torch.nn import init
from dgl.nn.pytorch import GraphConv, HeteroGraphConv
from torch.nn.parameter import Parameter
import warnings
warnings.filterwarnings("ignore")
class MatGRUGate(torch.nn.Module):
"""
GRU gate for matrix, similar to the official code.
Please refer to section 3.4 of the paper for the formula.
"""
    def __init__(self, rows, cols, activation):  # rows, cols, activation function
super().__init__()
self.activation = activation
self.W = Parameter(torch.Tensor(rows, rows))
self.U = Parameter(torch.Tensor(rows, rows))
self.bias = Parameter(torch.Tensor(rows, cols))
self.reset_parameters()
    def reset_parameters(self):  # initialize parameters
init.xavier_uniform_(self.W)
init.xavier_uniform_(self.U)
init.zeros_(self.bias)
def forward(self, x, hidden):
out = self.activation(self.W.matmul(x) + \
self.U.matmul(hidden) + \
self.bias)
return out
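A quick shape check for the gate (sizes here are arbitrary, chosen only for illustration): W and U are (rows, rows) and act on the matrices from the left, so the output keeps the (rows, cols) shape of the input.

rows, cols = 16, 8
gate = MatGRUGate(rows, cols, torch.nn.Sigmoid())
x = torch.randn(rows, cols)       # input matrix
hidden = torch.randn(rows, cols)  # hidden-state matrix
print(gate(x, hidden).shape)      # torch.Size([16, 8])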
class MatGRUCell(torch.nn.Module):
"""
GRU cell for matrix, similar to the official code.
Please refer to section 3.4 of the paper for the formula.
"""
def __init__(self, in_feats, out_feats):
super().__init__()
self.update = MatGRUGate(in_feats,
out_feats,
torch.nn.Sigmoid())
self.reset = MatGRUGate(in_feats,
out_feats,
torch.nn.Sigmoid())
self.htilda = MatGRUGate(in_feats,
out_feats,
torch.nn.Tanh())
    def forward(self, prev_Q, z_topk=None):
        # With z_topk omitted, the previous weights are fed back as the GRU
        # input, so the weights evolve on their own (the -O style update).
        if z_topk is None:
            z_topk = prev_Q
        update = self.update(z_topk, prev_Q)  # update gate
        reset = self.reset(z_topk, prev_Q)    # reset gate
        h_cap = reset * prev_Q                # gated previous state
        h_cap = self.htilda(z_topk, h_cap)    # candidate state
        new_Q = (1 - update) * prev_Q + update * h_cap
        return new_Q
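Used on its own (sizes again arbitrary), the cell evolves a weight matrix one step; with z_topk omitted the previous weights act as their own input, which is exactly how the two models below call it:

cell = MatGRUCell(in_feats=16, out_feats=8)
W = torch.randn(16, 8)  # a GCN weight matrix in the role of the hidden state
W_next = cell(W)        # evolved one time step, same shape
print(W_next.shape)     # torch.Size([16, 8])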
class EGCN(nn.Module):
    def __init__(self, layers=[]):  # e.g. num_layers=2, n_classes=2, output dim set to 6
super(EGCN, self).__init__()
self.layers_list = layers
self.num_layers = len(self.layers_list)-2
        self.recurrent_layers = nn.ModuleList()     # RNN: one matrix GRU cell per layer
        self.gnn_convs = nn.ModuleList()            # graph convolution layers
        self.gcn_weights_list = nn.ParameterList()  # initial GCN weights
        for i in range(self.num_layers + 1):
            if i == self.num_layers:
                # the last index builds the MLP classification head
self.mlp = nn.Sequential(nn.Linear(self.layers_list[i], 50),
nn.ReLU(),
nn.Linear(50, self.layers_list[i+1]),
                                     nn.Softmax(dim=1))
else:
self.recurrent_layers.append(
MatGRUCell(in_feats=self.layers_list[i],
out_feats=self.layers_list[i + 1]))
self.gcn_weights_list.append(
Parameter(torch.Tensor(self.layers_list[i],
self.layers_list[i + 1])))
self.gnn_convs.append(
GraphConv(in_feats=self.layers_list[i],
out_feats=self.layers_list[i + 1],
bias=False,
activation=nn.RReLU(),
weight=False,
allow_zero_in_degree=True))
self.reset_parameters()
def reset_parameters(self):
for gcn_weight in self.gcn_weights_list:
init.xavier_uniform_(gcn_weight)
def forward(self, g_list):
feature_list = []
for g in g_list:
feature_list.append(g.ndata['feature'])
        # Evolve the layer weights across time with the GRU cell, then apply
        # the i-th graph convolution to every snapshot with the evolved weights.
        for i in range(self.num_layers):
            W = self.gcn_weights_list[i]
            for j, g in enumerate(g_list):
                W = self.recurrent_layers[i](W)
                feature_list[j] = self.gnn_convs[i](g, feature_list[j], weight=W)
return self.mlp(feature_list[-1])
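A minimal end-to-end sketch for EGCN (all sizes hypothetical): three random snapshots with 10-dimensional node features, two evolving GCN layers (10 -> 32 -> 16), and a 2-class head, i.e. layers = [10, 32, 16, 2]; the output holds class probabilities for the nodes of the last snapshot.

import dgl

g_list = []
for _ in range(3):                            # three snapshots of the dynamic graph
    g = dgl.rand_graph(20, 50)                # 20 nodes, 50 random edges
    g.ndata['feature'] = torch.randn(20, 10)  # 10-dim node features
    g_list.append(g)

model = EGCN(layers=[10, 32, 16, 2])
out = model(g_list)
print(out.shape)                              # torch.Size([20, 2])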
class REGCN(nn.Module):
    def __init__(self, layers=[]):
super(REGCN, self).__init__()
self.layers_list = layers
self.num_layers = len(self.layers_list)-2
        self.gnn_convs = nn.ModuleList()         # graph convolution layers
        self.recurrent_layers = nn.ModuleDict()  # RNN: per-relation matrix GRU cells
        self.gcn_weights_list = nn.ModuleDict()  # per-relation initial GCN weights
for i in range(self.num_layers + 1):
if i == 0:
                # first-layer convolutions (per-relation input feature sizes)
self.gnn_convs.append(
HeteroGraphConv({
rel: GraphConv(self.layers_list[0][rel],
self.layers_list[i + 1],
allow_zero_in_degree=True,
weight=False,
) for rel in
self.layers_list[0]
}, aggregate='sum'))
self.gcn_weights_list.update({
rel:nn.ParameterList() for rel in self.layers_list[0]
})
self.recurrent_layers.update({
rel:nn.ModuleList() for rel in self.layers_list[0]
})
for rel in self.layers_list[0]:
self.gcn_weights_list[rel].append(
Parameter(
torch.Tensor(self.layers_list[i][rel],
self.layers_list[i + 1])),
)
self.recurrent_layers[rel].append(
MatGRUCell(in_feats=self.layers_list[i][rel],
out_feats=self.layers_list[i + 1]),
)
elif i == self.num_layers:
self.mlp = nn.Sequential(nn.Linear(self.layers_list[i], 50),
nn.ReLU(),
nn.Linear(50, self.layers_list[i+1]),
                                         nn.Softmax(dim=1))
else:
for rel in self.layers_list[0]:
self.gcn_weights_list[rel].append(
Parameter(torch.Tensor(self.layers_list[i],self.layers_list[i + 1]))
)
self.recurrent_layers[rel].append(
MatGRUCell(in_feats=self.layers_list[i],
out_feats=self.layers_list[i + 1])
)
self.gnn_convs.append(
HeteroGraphConv({
                        rel: GraphConv(self.layers_list[i],
self.layers_list[i + 1],
allow_zero_in_degree=True,
weight=False,) for rel in
self.layers_list[0]
}, aggregate='sum'))
self.reset_parameters()
def reset_parameters(self):
for rel in self.gcn_weights_list:
for gcn_weight in self.gcn_weights_list[rel]:
init.xavier_uniform_(gcn_weight)
def forward(self, g_list):
feature_list = []
for g in g_list:
            feature = {ntype: g.nodes[ntype].data["feature"] for ntype in g.ntypes}
feature_list.append(feature)
        # Evolve each relation's weight matrix across time with its GRU cell,
        # then run the hetero graph convolution of every snapshot with the
        # evolved weights passed in through mod_kwargs.
        for i in range(self.num_layers):
            W = {rel: {"weight": self.gcn_weights_list[rel][i]} for rel in self.layers_list[0]}
            for j, g in enumerate(g_list):
                W = {rel: {"weight": self.recurrent_layers[rel][i](W[rel]["weight"])}
                     for rel in self.layers_list[0]}
                feature_list[j] = self.gnn_convs[i](g, feature_list[j], mod_kwargs=W)
        return {ntype: self.mlp(feature_list[-1][ntype]) for ntype in g_list[0].ntypes}
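A minimal sketch for the heterogeneous variant (schema and sizes hypothetical): two node types, two relations, three snapshots; layers[0] maps each relation name to the feature size of its source node type.

import dgl

g_list = []
for _ in range(3):
    src, dst = torch.randint(0, 5, (30,)), torch.randint(0, 5, (30,))
    g = dgl.heterograph({
        ('user', 'clicks', 'item'): (src, dst),
        ('item', 'clicked_by', 'user'): (dst, src),
    }, num_nodes_dict={'user': 5, 'item': 5})
    g.nodes['user'].data['feature'] = torch.randn(5, 10)
    g.nodes['item'].data['feature'] = torch.randn(5, 10)
    g_list.append(g)

model = REGCN(layers=[{'clicks': 10, 'clicked_by': 10}, 32, 16, 2])
out = model(g_list)
print({ntype: o.shape for ntype, o in out.items()})  # (5, 2) per node type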