接上篇学习笔记GAT学习:PyG实现GAT(图注意力神经网络)网络(一)为了使得Attention的效果更好,所以加入multi-head attention。画个图说明multi-head attention的工作原理。
其实就相当于并联了head_num个attention后,将每个attention层的输出特征拼接起来,然后再输入一个attenion层得到输出结果。
预备知识
关于GAT的原理等知识,参考我的上篇博客:PyG实现GAT(图注意力神经网络)网络(一)
代码分析
import torch
import math
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops,degree
from torch_geometric.datasets import Planetoid
import ssl
import torch.nn.functional as F
class GAL(MessagePassing):
def __init__(self,in_features,out_featrues):
super(GAL,self).__init__(aggr='add')
self.a = torch.nn.Parameter(torch.zeros(size=(2*out_featrues, 1)))
torch.nn.init.xavier_uniform_(self.a.data, gain=1.414) # 初始化
# 定义leakyrelu激活函数
self.leakyrelu = torch.nn.LeakyReLU()
self.linear=torch.nn.Linear(in_features,out_featrues)
def forward(self,x,edge_index):
x=self.linear(x)
N=x.size()[0]
row,col=edge_index
a_input = torch.cat([x[row], x[col]], dim=1)
# [N, N, 1] => [N, N] 图注意力的相关系数(未归一化)
temp=torch.mm(a_input,self.a).squeeze()
e