Implementing a GAT with batch support
The original author's implementation does not support batching; I ran into a few pitfalls and modified it myself to handle batched input.
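For context, the reference pyGAT layer operates on a single graph (features of shape (N, C_in) with an (N, N) adjacency), while the layer below takes a batch of feature matrices. Here is a minimal usage sketch; it assumes the full GraphAttentionLayer definition given in the rest of this post, the (x, adj) call signature is carried over from the original pyGAT layer rather than confirmed here, and all sizes are illustrative.

import torch

B, N, C_in, C_out = 32, 10, 16, 8
x = torch.randn(B, N, C_in)              # batch of node feature matrices
adj = (torch.rand(N, N) > 0.5).float()   # illustrative binary adjacency shared across the batch
layer = GraphAttentionLayer(C_in, C_out, dropout=0.6, alpha=0.2, concat=True)
out = layer(x, adj)                      # expected shape: (B, N, C_out)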
from torch import nn
import torch
import torch.nn.functional as F
class nconv(nn.Module):
    """Propagates node features through an adjacency/support matrix via einsum."""
    def __init__(self):
        super(nconv, self).__init__()

    def forward(self, x, A):
        # x: (batch, channels, num_nodes, seq_len), A: (num_nodes, num_nodes)
        x = torch.einsum('ncvl,vw->ncwl', (x, A))
        return x.contiguous()
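As a quick shape check of the einsum above (a sketch with made-up tensor sizes), the node dimension v is contracted against the first axis of A, so the output keeps the same layout as the input:

x = torch.randn(4, 32, 207, 12)   # (batch, channels, num_nodes, seq_len)
A = torch.rand(207, 207)          # (num_nodes, num_nodes)
out = nconv()(x, A)
print(out.shape)                  # torch.Size([4, 32, 207, 12])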
class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    Graph attention layer
    input: (B, N, C_in)
    output: (B, N, C_out)
    """
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.in_features = in_features    # input feature dimension per node
        self.out_features = out_features  # output feature dimension per node
        self.dropout = dropout            # dropout probability
        self.alpha = alpha                # negative slope of the LeakyReLU
        self.concat = concat              # if True, apply an ELU activation at the end
        # trainable parameters, i.e. W and a from the paper
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)