PyG实现自定义GAT层
本系列中的第三篇介绍了如何调用pyg封装好的GAT函数,当然同样的,我们需要学会如何自定义网络层以满足研究需求。
完整代码
import torch
import math
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops,remove_self_loops,softmax
from torch_geometric.datasets import Planetoid
import ssl
import torch.nn.functional as F
class GATConv(MessagePassing):
def __init__(self, in_channels,out_channels, heads: int = 1, concat: bool = True,
negative_slope: float = 0.2, dropout: float = 0.,
add_self_loops: bool = True, bias: bool = True, **kwargs):
kwargs.setdefault('aggr', 'add')
super(GATConv, self).__init__(node_dim=0, **kwargs)
#in_channel&out channel就是我们的输入输出数
self.in_channels = in_channels
self.out_channels = out_channels
#head即设置几个attention头
self.heads = heads
#concat用于设置是否拼接attention的输出
self.concat = concat
#negative_slope设置leaklyRelu的参数
self.negative_slope = negative_slope
self.dropout = dropout
#add_self_loops设置是否添加自环
self.add_self_loops = add_self_loops
#这里将特征映射到每个attention头所需要的特征数,从而满足每个attention头的输入
self.lin = Linear(