图注意力网络(Graph Attention Network, GAT)的卷积层

 用于图神经网络(Graph Neural Networks, GNNs)。GAT通过引入自适应的注意力机制,使每个节点根据邻居节点的重要性来聚合信息。这个实现基于PyTorch和PyTorch Geometric库。代码的主要功能如下:

  • 构造函数 (__init__):

    • 初始化GAT层的各项参数,包括输入输出通道数、头数、是否使用残差连接等。
    • 初始化网络中使用的线性层和注意力权重参数。
  • 重置参数 (reset_parameters):

    • 对模型中的参数进行初始化,使用了glorot初始化方法和零初始化方法。
  • 前向传播 (forward):

    • 执行GAT的前向传播计算,包括节点特征的线性变换、注意力系数的计算、邻居节点特征的聚合以及输出的生成。
    • 支持双输入特征(源节点和目标节点)以及选择性地返回注意力权重。
  • 边的更新 (edge_update):

    • 计算边的注意力权重,结合源节点和目标节点的特征,应用LeakyReLU、softmax和dropout操作。
  • 消息传递 (message):

    • 计算最终的消息传递操作,将注意力权重应用到邻居节点的特征上。

该卷积层用到的数学公式如下:

 

参数(Args):

  • in_channels (int 或 tuple):每个输入样本的大小,或者设为 -1forward 方法的第一个输入中推导大小。元组形式用于二分图的源和目标维度。
  • out_channels (int):每个输出样本的大小。
  • heads (int, 可选):多头注意力的数量(默认为 1)。
  • concat (bool, 可选):如果设为 False,多头注意力将被平均而不是拼接(默认为 True)。
  • negative_slope (float, 可选):LeakyReLU 的负斜率角度(默认为 0.2)。
  • dropout (float, 可选):归一化的注意力系数的 Dropout 概率,这使每个节点在训练期间暴露于随机采样的邻域(默认为 0)。
  • add_self_loops (bool, 可选):如果设为 False,将不会在输入图中添加自环(默认为 True)。
  • edge_dim (int, 可选):边特征的维度(如果有的话)(默认为 None)。
  • fill_value (float 或 torch.Tensor 或 str, 可选):生成自环边特征的方式(如果 edge_dim != None)。如果是 floattorch.Tensor,自环边特征将直接由 fill_value 给出。如果是 str,自环边特征通过聚合指向特定节点的所有边的特征来计算,根据减少操作("add"、"mean"、"min"、"max"、"mul")(默认为 "mean")。
  • bias (bool, 可选):如果设为 False,该层将不学习可加的偏置(默认为 True)。
  • residual (bool, 可选):如果设为 True,该层将添加一个可学习的跳跃连接(默认为 False)。
  • kwargs (可选):额外的参数传递给 torch_geometric.nn.conv.MessagePassing

 数据形状如下:

 对于forward函数,运行模块的前向传播(forward pass)。

参数(Args):

  • x (torch.Tensor 或 (torch.Tensor, torch.Tensor)):输入的节点特征。
  • edge_index (torch.Tensor 或 SparseTensor):边的索引。
  • edge_attr (torch.Tensor, 可选):边的特征。(默认为 None
  • size ((int, int), 可选):邻接矩阵的形状。(默认为 None
  • return_attention_weights (bool, 可选):如果设为 True,将额外返回一个元组 (edge_index, attention_weights),包含每条边计算得到的注意力权重。(默认为 None
import torch
from torch import Tensor
from torch.nn import Parameter
from torch_geometric.typing import Adj, OptTensor, SparseTensor
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.utils import remove_self_loops, add_self_loops
from torch_geometric.nn.conv import MessagePassing

class GATConv(MessagePassing):
    def __init__(self, in_channels, out_channels, heads=1, concat=True, negative_slope=0.2,
                 dropout=0., add_self_loops=True, bias=True, **kwargs):
        super(GATConv, self).__init__(aggr='add', node_dim=0, **kwargs)
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops
        
        # Define the learnable parameters.
        self.lin = torch.nn.Linear(in_channels, heads * out_channels, bias=False)
        self.att = Parameter(torch.Tensor(1, heads, 2 * out_channels))
        
        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        
        # Reset parameters using glorot and zeros initialization.
        self.reset_parameters()
    
    def reset_parameters(self):
        glorot(self.lin.weight)
        glorot(self.att)
        zeros(self.bias)
    
    def forward(self, x: Tensor, edge_index: Adj, return_attention_weights=False):
        # Add self-loops to the adjacency matrix.
        if self.add_self_loops:
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Apply linear transformation.
        x = self.lin(x)
        
        # Start propagating messages.
        out = self.propagate(edge_index, x=(x, x), return_attention_weights=return_attention_weights)
        
        # If heads are concatenated, reshape the output accordingly.
        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)
        
        if self.bias is not None:
            out += self.bias
        
        if isinstance(out, tuple):
            return out[0], out[1]
        else:
            return out
    
    def edge_update(self, edge_index: Tensor, x_i: Tensor, x_j: Tensor) -> Tensor:
        # Compute attention scores.
        x_i = x_i.view(-1, self.heads, self.out_channels)
        x_j = x_j.view(-1, self.heads, self.out_channels)
        alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
        alpha = torch.nn.functional.leaky_relu(alpha, self.negative_slope)
        alpha = torch.nn.functional.softmax(alpha, edge_index[0])
        alpha = torch.nn.functional.dropout(alpha, p=self.dropout, training=self.training)
        return alpha
    
    def message(self, edge_index: Tensor, x_j: Tensor, alpha: Tensor) -> Tensor:
        # Compute the final message as attention-weighted node features.
        return alpha.unsqueeze(-1) * x_j
    
    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        # Specialized case for sparse tensors.
        return torch_sparse.matmul(adj_t, x, reduce=self.aggr)
    
    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels}, {self.out_channels}, heads={self.heads})'

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值