1.首先先讲一下代码
这是官方给的代码:torch_geometric.nn.conv.transformer_conv — pytorch_geometric documentation
import math
import typing
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import (
Adj,
NoneType,
OptTensor,
PairTensor,
SparseTensor,
)
from torch_geometric.utils import softmax
if typing.TYPE_CHECKING:
from typing import overload
else:
from torch.jit import _overload_method as overload
[docs]class TransformerConv(MessagePassing):
r"""The graph transformer operator from the `"Masked Label Prediction:
Unified Message Passing Model for Semi-Supervised Classification"
<https://arxiv.org/abs/2009.03509>`_ paper.
.. math::
\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
\sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},
where the attention coefficients :math:`\alpha_{i,j}` are computed via
multi-head dot product attention:
.. math::
\alpha_{i,j} = \textrm{softmax} \left(
\frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)}
{\sqrt{d}} \right)
Args:
in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
derive the size from the first input(s) to the forward method.
A tuple corresponds to the sizes of source and target
dimensionalities.
out_channels (int): Size of each output sample.
heads (int, optional): Number of multi-head-attentions.
(default: :obj:`1`)
concat (bool, optional): If set to :obj:`False`, the multi-head
attentions are averaged instead of concatenated.
(default: :obj:`True`)
beta (bool, optional): If set, will combine aggregation and
skip information via
.. math::
\mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i +
(1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}
\alpha_{i,j} \mathbf{W}_2 \vec{x}_j \right)}_{=\mathbf{m}_i}
with :math:`\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top}
[ \mathbf{W}_1 \mathbf{x}_i, \mathbf{m}_i, \mathbf{W}_1
\mathbf{x}_i - \mathbf{m}_i ])` (default: :obj:`False`)
dropout (float, optional): Dropout probability of the normalized
attention coefficients which exposes each node to a stochastically
sampled neighborhood during training. (default: :obj:`0`)
edge_dim (int, optional): Edge feature dimensionality (in case
there are any). Edge features are added to the keys after
linear transformation, that is, prior to computing the
attention dot product. They are also added to final values
after the same linear transformation. The model is:
.. math::
\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
\sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left(
\mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij}
\right),
where the attention coefficients :math:`\alpha_{i,j}` are now
computed via:
.. math::
\alpha_{i,j} = \textrm{softmax} \left(
\frac{(\mathbf{W}_3\mathbf{x}_i)^{\top}
(\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})}
{\sqrt{d}} \right)
(default :obj:`None`)
bias (bool, optional): If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
root_weight (bool, optional): If set to :obj:`False`, the layer will
not add the transformed root node features to the output and the
option :attr:`beta` is set to :obj:`False`. (default: :obj:`True`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
"""
_alpha: OptTensor
def __init__(
self,
in_channels: Union[int, Tuple[int, int]],
out_channels: int,
heads: int = 1,
concat: bool = True,
beta: bool = False,
dropout: float = 0.,
edge_dim: Optional[int] = None,
bias: bool = True,
root_weight: bool = True,
**kwargs,
):
kwargs.setdefault('aggr', 'add')
super().__init__(node_dim=0, **kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.heads = heads
self.beta = beta and root_weight
self.root_weight = root_weight
self.concat = concat
self.dropout = dropout
self.edge_dim = edge_dim
self._alpha = None
if isinstance(in_channels, int):
in_channels = (in_channels, in_channels)
self.lin_key = Linear(in_channels[0], heads * out_channels)
self.lin_query = Linear(in_channels[1], heads * out_channels)
self.lin_value = Linear(in_channels[0], heads * out_channels)
if edge_dim is not None:
self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)
else:
self.lin_edge = self.register_parameter('lin_edge', None)
if concat:
self.lin_skip = Linear(in_channels[1], heads * out_channels,
bias=bias)
if self.beta:
self.lin_beta = Linear(3 * heads * out_channels, 1, bias=False)
else:
self.lin_beta = self.register_parameter('lin_beta', None)
else:
self.lin_skip = Linear(in_channels[1], out_channels, bias=bias)
if self.beta:
self.lin_beta = Linear(3 * out_channels, 1, bias=False)
else:
self.lin_beta = self.register_parameter('lin_beta', None)
self.reset_parameters()
[docs] def reset_parameters(self):
super().reset_parameters()
self.lin_key.reset_parameters()
self.lin_query.reset_parameters()
self.lin_value.reset_parameters()
if self.edge_dim:
self.lin_edge.reset_parameters()
self.lin_skip.reset_parameters()
if self.beta:
self.lin_beta.reset_parameters()
@overload
def forward(
self,
x: Union[Tensor, PairTensor],
edge_index: Adj,
edge_attr: OptTensor = None,
return_attention_weights: NoneType = None,
) -> Tensor:
pass
@overload
def forward( # noqa: F811
self,
x: Union[Tensor, PairTensor],
edge_index: Tensor,
edge_attr: OptTensor = None,
return_attention_weights: bool = None,
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
pass
@overload
def forward( # noqa: F811
self,
x: Union[Tensor, PairTensor],
edge_index: SparseTensor,
edge_attr: OptTensor = None,
return_attention_weights: bool = None,
) -> Tuple[Tensor, SparseTensor]:
pass
[docs] def forward( # noqa: F811
self,
x: Union[Tensor, PairTensor],
edge_index: Adj,
edge_attr: OptTensor = None,
return_attention_weights: Optional[bool] = None,
) -> Union[
Tensor,
Tuple[Tensor, Tuple[Tensor, Tensor]],
Tuple[Tensor, SparseTensor],
]:
r"""Runs the forward pass of the module.
Args:
x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input node
features.
edge_index (torch.Tensor or SparseTensor): The edge indices.
edge_attr (torch.Tensor, optional): The edge features.
(default: :obj:`None`)
return_attention_weights (bool, optional): If set to :obj:`True`,
will additionally return the tuple
:obj:`(edge_index, attention_weights)`, holding the computed
attention weights for each edge. (default: :obj:`None`)
"""
H, C = self.heads, self.out_channels
if isinstance(x, Tensor):
x = (x, x)
query = self.lin_query(x[1]).view(-1, H, C)
key = self.lin_key(x[0]).view(-1, H, C)
value = self.lin_value(x[0]).view(-1, H, C)
# propagate_type: (query: Tensor, key:Tensor, value: Tensor,
# edge_attr: OptTensor)
out = self.propagate(edge_index, query=query, key=key, value=value,
edge_attr=edge_attr)
alpha = self._alpha
self._alpha = None
if self.concat:
out = out.view(-1, self.heads * self.out_channels)
else:
out = out.mean(dim=1)
if self.root_weight:
x_r = self.lin_skip(x[1])
if self.lin_beta is not None:
beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1))
beta = beta.sigmoid()
out = beta * x_r + (1 - beta) * out
else:
out = out + x_r
if isinstance(return_attention_weights, bool):
assert alpha is not None
if isinstance(edge_index, Tensor):
return out, (edge_index, alpha)
elif isinstance(edge_index, SparseTensor):
return out, edge_index.set_value(alpha, layout='coo')
else:
return out
def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,
edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
size_i: Optional[int]) -> Tensor:
if self.lin_edge is not None:
assert edge_attr is not None
edge_attr = self.lin_edge(edge_attr).view(-1, self.heads,
self.out_channels)
key_j = key_j + edge_attr
alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels)
alpha = softmax(alpha, index, ptr, size_i)
self._alpha = alpha
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
out = value_j
if edge_attr is not None:
out = out + edge_attr
out = out * alpha.view(-1, self.heads, 1)
return out
def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '
f'{self.out_channels}, heads={self.heads})')
2.详细解释一下
几个重要的参数
in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.
out_channels (int): Size of each output sample.
heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`)
怎么理解这几个参数?感觉之前自己学习看东西不细,今天又来重复学习了。还是得看英文原版,学习是一点不能走捷径。
class TransformerConv(in_channels:Union[int,Tuple[int,int]],out_channels:int,heads:int = 1,concat:boll=True,beta:bool=False,dropout:float=0.0,edge_dim:Optimal[int]=None,bias:bool=True,root_weight:bool=True,**kwargs)
in_channels
表示每个输入样本的大小。如果设置为整数,则表示所有输入样本的大小相同;如果设置为-1
,则表示输入样本的大小将从forward
方法的第一个输入中推导出来;如果设置为元组,则表示输入样本的大小对应于源维度和目标维度的大小。其中in_channels:Union[int,Tuple[int,int]]。就是说in_channels后面的参数是可以是int或者是元组的形式。
整数 (int):当
in_channels
是单个整数时,它表示所有输入节点特征的维度相同。例如,如果in_channels=16
,则每个节点的输入特征维度都是 16。整数元组 (Tuple[int, int]):当
in_channels
是一个整数元组时,它表示源节点和目标节点的输入特征维度分别不同。例如,in_channels=(16, 32)
表示源节点的输入特征维度是 16,而目标节点的输入特征维度是 32。这种情况通常用于异构图或双向图神经网络中。import torch from torch_geometric.nn import TransformerConv # 节点特征矩阵 x = torch.randn((10, 16)) # 10 个节点,每个节点 16 维特征 # 边索引矩阵 edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]], dtype=torch.long) # 创建 TransformerConv 层(单一输入维度) conv = TransformerConv(in_channels=16, out_channels=32, heads=1, dropout=0.1) # 应用 TransformerConv 层 out = conv(x, edge_index) print(out.shape) # 输出特征矩阵,形状为 (10, 32)
这个
forward
方法是图神经网络中用来进行前向传播(forward pass)的关键方法。它接收输入特征和图的结构信息,并计算输出特征。参数解释
x (Union[Tensor, PairTensor]):
- 输入节点的特征。可以是单个张量,也可以是一个包含两个张量的元组(PairTensor)。
- 如果是单个张量,形状通常为 [num_nodes,in_channels][num\_nodes, in\_channels][num_nodes,in_channels],表示节点数量和每个节点的特征维度。
- 如果是 PairTensor,意味着我们可能处理的是异构图,源节点和目标节点的特征可以不同。
edge_index (Adj):
- 图的边索引,表示节点之间的连接关系。通常是形状为 [2,num_edges][2, num\_edges][2,num_edges] 的张量,表示每条边的起点和终点。
Adj
类型可以是稀疏矩阵(SparseTensor)或者边索引矩阵(edge_index)。edge_attr (OptTensor):
- 边的特征,可选参数。如果存在,则形状通常为 [num_edges,edge_feature_dim][num\_edges, edge\_feature\_dim][num_edges,edge_feature_dim],表示每条边的特征。
返回值
返回值的类型是一个联合类型(Union),具体取决于
return_attention_weights
的值:Tensor:
- 如果
return_attention_weights
为False
,则返回更新后的节点特征。形状通常为 [num_nodes,out_channels][num\_nodes, out\_channels][num_nodes,out_channels]。Tuple[Tensor, Tuple[Tensor, Tensor]]:
- 如果
return_attention_weights
为True
并且使用的是普通的边索引矩阵,则返回一个包含两个元素的元组:
- 第一个元素是更新后的节点特征。
- 第二个元素是包含两个张量的元组,表示注意力权重及其对应的边索引。
Tuple[Tensor, SparseTensor]:
- 如果
return_attention_weights
为True
并且使用的是稀疏矩阵,则返回一个包含两个元素的元组:
- 第一个元素是更新后的节点特征。
- 第二个元素是稀疏矩阵形式的注意力权重。
return_attention_weights (Optional[bool]):
- 是否返回注意力权重(attention weights)。如果设置为
True
,前向传播会返回注意力权重。
out_channels
表示每个输出样本的大小,即经过卷积操作后产生的特征向量的维度大小。
当使用
tg.nn.TransformerConv
时,可以通过以下方式理解in_channels
和out_channels
:假设我们有一个图数据集,每个节点都有一个 10 维的特征向量表示。那么在这种情况下:
如果我们想将每个节点的特征向量作为输入,然后使用
tg.nn.TransformerConv
进行卷积操作,那么in_channels
应该设置为 10,表示每个输入样本的大小为 10。假设我们想将节点的特征向量转换为一个 16 维的特征向量,那么
out_channels
应该设置为 16,表示每个输出样本的大小为 16,即经过卷积操作后每个节点的特征向量将变为 16 维。在
tg.nn.TransformerConv
中,heads
参数表示多头注意力的数量。举个例子,如果heads
参数设置为 4,那么模型将学习 4 组注意力权重,每组权重都用于计算输入的不同子空间的注意力,然后将这些头的输出进行合并以产生最终的输出。
举个整体的例子:
我们有一个输入张量
x
,它的形状是(batch_size, seq_length, input_dim)
,其中:
batch_size
表示批量大小;seq_length
表示序列长度;input_dim
表示输入特征的维度。现在假设我们使用了
tg.nn.TransformerConv
,并设置heads=2
,那么模型将学习两组注意力权重,每组用于计算不同的注意力。输出张量的形状将取决于out_channels
参数,我们假设out_channels=64
。
import torch
import torch_geometric.nn as tg
# 假设输入张量的形状是 (batch_size, seq_length, input_dim)
x = torch.randn(32, 10, 128) # 32 个样本,每个样本有 10 个时间步,每个时间步有 128 个特征
# 创建 TransformerConv 模型,设置 heads=2,out_channels=64
conv_layer = tg.nn.TransformerConv(in_channels=128, out_channels=64, heads=2)
# 使用模型进行前向传播
output = conv_layer(x)
print("输出张量的形状:", output.shape)
2.1将特征映射到键值对中
在这里,通过线性变换层 Linear
,输入特征被转换成了键(key)、查询(query)和数值(value)的表示形式,以便用于多头自注意力机制。
具体来说:
self.lin_key
用于将输入特征(in_channels[0])映射到键的表示形式。self.lin_query
用于将输入特征(in_channels[1])映射到查询的表示形式。self.lin_value
用于将输入特征(in_channels[0])映射到数值的表示形式。
具体地,假设输入特征的维度是 (batch_size, num_nodes, in_channels)
,其中 batch_size
是批量大小,num_nodes
是节点数,in_channels
是输入特征的通道数。在映射到键的过程中,线性变换层的权重矩阵将是一个维度为 (in_channels, heads * out_channels)
的矩阵,其中 heads
是注意力头的数量,out_channels
是输出特征的通道数。因此,通过矩阵乘法运算,输入特征将被映射到一个新的特征空间,其维度为 (batch_size, num_nodes, heads, out_channels)
。在这个新的特征空间中,每个节点的每个头都有一个键表示。
3.里面的数学表达式意义
具体的卷积操作
-
特征变换: 首先,对输入特征进行线性变换。这一步的作用是将输入的节点特征映射到一个新的特征空间。
-
计算注意力权重: 使用注意力机制来计算节点之间的注意力权重。这个过程包括以下步骤:
- 对每个节点的特征向量进行变换,得到查询向量(query)、键向量(key)和值向量(value)。
- 计算查询向量和键向量的点积,并除以缩放因子(通常是键向量维度的平方根),然后应用 softmax 函数得到注意力权重。
-
消息传递: 根据计算得到的注意力权重,将邻居节点的特征加权求和,形成新的节点表示。
-
特征更新: 将消息传递得到的新节点表示进行线性变换,并根据需要进行拼接或求平均操作。
-
具体的这个W是怎么学到的呢?
-
输入的特征维度不一定和输出的特征维度一样。该步骤是作为一个特征提取的过程
-
我们看一下论文中一般怎么定义这个过程吧!
为了聚合节点信息