Course link: CS224W
The part of Colab 4 you have to implement yourself is a single GAT layer; the single-layer GAT implementation is walked through below:
You can refer to PyG's own implementation: https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/gat_conv.html#GATConv
The various method definitions of the MessagePassing class are omitted here; they are in the colab. The relevant formulas and code follow.
Formulas:
GAT aggregates a node's neighbor information with learned weights; the overall update is formula (1):

$$h_i' = \sigma\Big(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}\, W h_j\Big) \tag{1}$$

where $\alpha_{ij}$ is the attention weight, given by formula (2) (written in the split form that the code below implements):

$$\alpha_{ij} = \frac{\exp\big(\mathrm{LeakyReLU}(\mathbf{a}_r^\top W_r h_i + \mathbf{a}_l^\top W_l h_j)\big)}{\sum_{k \in \mathcal{N}(i)} \exp\big(\mathrm{LeakyReLU}(\mathbf{a}_r^\top W_r h_i + \mathbf{a}_l^\top W_l h_k)\big)} \tag{2}$$
The two formulas above are exactly the process implemented below (training is done in the transductive setting).
(The implementation below assumes an undirected graph.)
Implementation logic:
1. Apply a linear transformation to every node (compute Wh).
This turns the [N, F] feature matrix into an [N, F'] one (with F' = H * C split across heads).
In the forward function:
w_l = self.lin_l(x).view(-1, H, C)  # [N, H, C]
w_r = self.lin_r(x).view(-1, H, C)  # [N, H, C]
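A minimal shape check of this step (N, F_in, H, C below are made-up sizes, and lin_l is a stand-in for self.lin_l):

import torch
import torch.nn as nn

# Made-up sizes: N nodes, F_in input features, H heads, C channels per head.
N, F_in, H, C = 4, 8, 2, 16

lin_l = nn.Linear(F_in, H * C)   # the [N, F] -> [N, H*C] transform
x = torch.randn(N, F_in)

w_l = lin_l(x).view(-1, H, C)    # split the heads out: [N, H, C]
print(w_l.shape)                 # torch.Size([4, 2, 16])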
2. Compute the attention, following formula (2) step by step (the per-node terms below live in forward; the rest happens in the message function).
For each node pair (i, j), first compute alpha_l and alpha_r, i.e.:
alpha_l = (w_l * self.att_l).sum(-1)  # alpha_l has shape [N, H]
alpha_r = (w_r * self.att_r).sum(-1)  # [N, H]
In the message function, the two terms are first summed per edge (alpha = alpha_i + alpha_j), then LeakyReLU, softmax, and dropout are applied:
alpha = F.leaky_relu(alpha, self.negative_slope)
alpha = softmax(alpha, index, ptr, size_i)
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
# alpha now has shape [E, H]
The resulting alpha is the attention weight of each edge (i, j); softmax here is torch_geometric.utils.softmax, which normalizes over each target node's incoming edges rather than over the whole tensor.
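As a rough sketch of what that grouped softmax computes (a hand-rolled illustration, not PyG's actual implementation; it shifts by a global max for stability where PyG uses a per-group max):

import torch

def segment_softmax(alpha, index, num_nodes):
    # alpha: [E, H] raw scores; index: [E] target node of each edge.
    # Normalizes over all edges that point at the same target node.
    exp = (alpha - alpha.max()).exp()            # global-max shift (simplification)
    denom = torch.zeros(num_nodes, alpha.size(1))
    denom.index_add_(0, index, exp)              # per-target-node sum of exp
    return exp / denom[index].clamp(min=1e-16)

For example, with index = torch.tensor([0, 0, 1]) the first two rows sum to 1 per head and the third row is 1 on its own.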
3. Weight the messages by the attention in the message function.
Here E is the number of edges of the input graph; since training is transductive, E is the edge count of the entire graph.
In the message function:
alpha = alpha.unsqueeze(-1)  # alpha: [E, H, 1] (expand the last dimension)
out = x_j * alpha  # [E, H, C] * [E, H, 1], elementwise; [E, H, 1] is broadcast to [E, H, C] (see 'PyTorch broadcasting semantics' if in doubt)
# out: [E, H, C]
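A minimal broadcasting demo with made-up sizes (E, H, C are placeholders):

import torch

E, H, C = 10, 2, 16
x_j = torch.randn(E, H, C)    # source-node features, one row per edge
alpha = torch.rand(E, H, 1)   # per-edge, per-head attention weights

out = x_j * alpha             # [E, H, 1] broadcasts along the last dim
print(out.shape)              # torch.Size([10, 2, 16])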
4. Aggregate: sum the weighted embeddings over each node's neighbors (a scatter-sum in the aggregate function), as sketched below.
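The scatter-sum that aggregate performs can be tried on toy values (assumes torch_scatter is installed; the numbers are made up):

import torch
import torch_scatter

msgs = torch.ones(4, 2, 3)             # 4 edge messages, 2 heads, 3 channels
index = torch.tensor([0, 0, 1, 2])     # target node of each message

out = torch_scatter.scatter(msgs, index, dim=0, reduce='sum')
print(out.shape)                       # torch.Size([3, 2, 3])
print(out[0, 0])                       # tensor([2., 2., 2.]) -- node 0 received two messages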
Full single-layer GAT code (with the imports it needs):
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_scatter
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import softmax


class GAT(MessagePassing):
    def __init__(self, in_channels, out_channels, heads=2,
                 negative_slope=0.2, dropout=0., **kwargs):
        super(GAT, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout

        self.lin_l = None
        self.lin_r = None
        self.att_l = None
        self.att_r = None

        ########## Define the layers: two linear layers and two attention weight vectors ##########
        # TODO: Your code here!
        self.lin_l = nn.Linear(self.in_channels, self.out_channels * self.heads)
        self.lin_r = self.lin_l  # non-bipartite graph: source and target share one linear layer
        self.att_l = Parameter(torch.Tensor(1, self.heads, self.out_channels))
        self.att_r = Parameter(torch.Tensor(1, self.heads, self.out_channels))
        ############################################################################

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.lin_l.weight)
        nn.init.xavier_uniform_(self.lin_r.weight)
        nn.init.xavier_uniform_(self.att_l)
        nn.init.xavier_uniform_(self.att_r)

    def forward(self, x, edge_index, size=None):
        H, C = self.heads, self.out_channels
        ############################################################################
        # TODO: Your code here!
        w_l = self.lin_l(x).view(-1, H, C)        # [N, H, C]
        w_r = self.lin_r(x).view(-1, H, C)        # [N, H, C]
        alpha_l = (w_l * self.att_l).sum(-1)      # [N, H]
        alpha_r = (w_r * self.att_r).sum(-1)      # [N, H]
        # Concatenate the heads at the end: [N, H, C] -> [N, H*C]
        out = self.propagate(edge_index, x=(w_l, w_r),
                             alpha=(alpha_l, alpha_r), size=size).view(-1, H * C)
        ############################################################################
        return out

    def message(self, x_j, alpha_j, alpha_i, index, ptr, size_i):
        ############################################################################
        # TODO: Your code here!
        # alpha_i, alpha_j: [E, H]; x_j: [E, H, C]
        alpha = alpha_i + alpha_j                       # raw score e_ij per edge
        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, index, ptr, size_i)      # normalize per target node
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        alpha = alpha.unsqueeze(-1)                     # [E, H, 1]
        out = x_j * alpha                               # broadcast to [E, H, C]
        ############################################################################
        return out

    def aggregate(self, inputs, index, dim_size=None):
        ############################################################################
        # TODO: Your code here!
        out = torch_scatter.scatter(inputs, index, dim=self.node_dim,
                                    dim_size=dim_size, reduce='sum')
        ############################################################################
        return out
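A quick smoke test on a toy graph (made-up sizes; assumes the imports above):

# Toy undirected graph: 3 nodes, both directions of 2 undirected edges.
x = torch.randn(3, 8)                              # [N, F]
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])          # [2, E]

gat = GAT(in_channels=8, out_channels=16, heads=2)
out = gat(x, edge_index)
print(out.shape)                                   # torch.Size([3, 32]) = [N, H*C]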