attention层的实现
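重点是理解 `_prepare_attentional_mechanism_input` 函数的作用：它把任意两个节点的特征两两拼接，即对每一对节点 $(i, j)$ 构造 $[\mathbf{W}h_i \,\|\, \mathbf{W}h_j]$。这对应原文中注意力系数的计算公式：$e_{ij} = \mathrm{LeakyReLU}\left(\mathbf{a}^{T}\,[\mathbf{W}h_i \,\|\, \mathbf{W}h_j]\right)$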
class GraphAttentionLayer(nn.Module):
"""
Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
"""
def __init__(self, in_features, out_features, dropout, alpha, concat=True):
super(GraphAttentionLayer, self).__init__()
self.dropout = dropout
self.in_features = in_features
self.out_features = out_features
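# alpha: presumably the negative slope of the LeakyReLU applied to the attention scores (as in the GAT paper), not a learning rate -- confirm against where self.alpha is used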
self.alpha = alpha
self.concat = concat
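# Allocate the weight matrix W with shape (in_features, out_features); torch.empty leaves it uninitialized (not zeros), so its values come from the Xavier initialization that follows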
self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
# xavier初始化
nn.init.xavier_uniform_(self.W.data, gain=1.414)
# 这里的self.a,对应的是论文里的向量a,故其维度大小应该为(2*out_features, 1)
self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
nn.init.xavier_uniform_(self.