用于图神经网络(Graph Neural Networks, GNNs)。GAT通过引入自适应的注意力机制,使每个节点根据邻居节点的重要性来聚合信息。这个实现基于PyTorch和PyTorch Geometric库。代码的主要功能如下:
-
构造函数 (
__init__
):- 初始化GAT层的各项参数,包括输入输出通道数、头数、是否使用残差连接等。
- 初始化网络中使用的线性层和注意力权重参数。
-
重置参数 (
reset_parameters
):- 对模型中的参数进行初始化,使用了
glorot
初始化方法和零初始化方法。
- 对模型中的参数进行初始化,使用了
-
前向传播 (
forward
):- 执行GAT的前向传播计算,包括节点特征的线性变换、注意力系数的计算、邻居节点特征的聚合以及输出的生成。
- 支持双输入特征(源节点和目标节点)以及选择性地返回注意力权重。
-
边的更新 (
edge_update
):- 计算边的注意力权重,结合源节点和目标节点的特征,应用LeakyReLU、softmax和dropout操作。
-
消息传递 (
message
):- 计算最终的消息传递操作,将注意力权重应用到邻居节点的特征上。
该卷积层用到的数学公式如下:
参数(Args):
- in_channels (int 或 tuple):每个输入样本的大小,或者设为
-1
从forward
方法的第一个输入中推导大小。元组形式用于二分图的源和目标维度。 - out_channels (int):每个输出样本的大小。
- heads (int, 可选):多头注意力的数量(默认为 1)。
- concat (bool, 可选):如果设为
False
,多头注意力将被平均而不是拼接(默认为True
)。 - negative_slope (float, 可选):LeakyReLU 的负斜率角度(默认为 0.2)。
- dropout (float, 可选):归一化的注意力系数的 Dropout 概率,这使每个节点在训练期间暴露于随机采样的邻域(默认为 0)。
- add_self_loops (bool, 可选):如果设为
False
,将不会在输入图中添加自环(默认为True
)。 - edge_dim (int, 可选):边特征的维度(如果有的话)(默认为
None
)。 - fill_value (float 或 torch.Tensor 或 str, 可选):生成自环边特征的方式(如果
edge_dim != None
)。如果是float
或torch.Tensor
,自环边特征将直接由fill_value
给出。如果是str
,自环边特征通过聚合指向特定节点的所有边的特征来计算,根据减少操作("add"、"mean"、"min"、"max"、"mul")(默认为 "mean")。 - bias (bool, 可选):如果设为
False
,该层将不学习可加的偏置(默认为True
)。 - residual (bool, 可选):如果设为
True
,该层将添加一个可学习的跳跃连接(默认为False
)。 - kwargs (可选):额外的参数传递给
torch_geometric.nn.conv.MessagePassing
。
数据形状如下:
对于forward函数,运行模块的前向传播(forward pass)。
参数(Args):
- x (torch.Tensor 或 (torch.Tensor, torch.Tensor)):输入的节点特征。
- edge_index (torch.Tensor 或 SparseTensor):边的索引。
- edge_attr (torch.Tensor, 可选):边的特征。(默认为
None
) - size ((int, int), 可选):邻接矩阵的形状。(默认为
None
) - return_attention_weights (bool, 可选):如果设为
True
,将额外返回一个元组(edge_index, attention_weights)
,包含每条边计算得到的注意力权重。(默认为None
)
import torch
from torch import Tensor
from torch.nn import Parameter
from torch_geometric.typing import Adj, OptTensor, SparseTensor
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.utils import remove_self_loops, add_self_loops
from torch_geometric.nn.conv import MessagePassing
class GATConv(MessagePassing):
def __init__(self, in_channels, out_channels, heads=1, concat=True, negative_slope=0.2,
dropout=0., add_self_loops=True, bias=True, **kwargs):
super(GATConv, self).__init__(aggr='add', node_dim=0, **kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.heads = heads
self.concat = concat
self.negative_slope = negative_slope
self.dropout = dropout
self.add_self_loops = add_self_loops
# Define the learnable parameters.
self.lin = torch.nn.Linear(in_channels, heads * out_channels, bias=False)
self.att = Parameter(torch.Tensor(1, heads, 2 * out_channels))
if bias and concat:
self.bias = Parameter(torch.Tensor(heads * out_channels))
elif bias:
self.bias = Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
# Reset parameters using glorot and zeros initialization.
self.reset_parameters()
def reset_parameters(self):
glorot(self.lin.weight)
glorot(self.att)
zeros(self.bias)
def forward(self, x: Tensor, edge_index: Adj, return_attention_weights=False):
# Add self-loops to the adjacency matrix.
if self.add_self_loops:
edge_index, _ = remove_self_loops(edge_index)
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Apply linear transformation.
x = self.lin(x)
# Start propagating messages.
out = self.propagate(edge_index, x=(x, x), return_attention_weights=return_attention_weights)
# If heads are concatenated, reshape the output accordingly.
if self.concat:
out = out.view(-1, self.heads * self.out_channels)
else:
out = out.mean(dim=1)
if self.bias is not None:
out += self.bias
if isinstance(out, tuple):
return out[0], out[1]
else:
return out
def edge_update(self, edge_index: Tensor, x_i: Tensor, x_j: Tensor) -> Tensor:
# Compute attention scores.
x_i = x_i.view(-1, self.heads, self.out_channels)
x_j = x_j.view(-1, self.heads, self.out_channels)
alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
alpha = torch.nn.functional.leaky_relu(alpha, self.negative_slope)
alpha = torch.nn.functional.softmax(alpha, edge_index[0])
alpha = torch.nn.functional.dropout(alpha, p=self.dropout, training=self.training)
return alpha
def message(self, edge_index: Tensor, x_j: Tensor, alpha: Tensor) -> Tensor:
# Compute the final message as attention-weighted node features.
return alpha.unsqueeze(-1) * x_j
def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
# Specialized case for sparse tensors.
return torch_sparse.matmul(adj_t, x, reduce=self.aggr)
def __repr__(self):
return f'{self.__class__.__name__}({self.in_channels}, {self.out_channels}, heads={self.heads})'