详细讲一下PYG 里面的torch_geometric.nn.conv.transformer_conv函数

1.首先先讲一下代码

这是官方给的代码:torch_geometric.nn.conv.transformer_conv — pytorch_geometric documentation

import math
import typing
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import (
    Adj,
    NoneType,
    OptTensor,
    PairTensor,
    SparseTensor,
)
from torch_geometric.utils import softmax

if typing.TYPE_CHECKING:
    from typing import overload
else:
    from torch.jit import _overload_method as overload


[docs]class TransformerConv(MessagePassing):
    r"""The graph transformer operator from the `"Masked Label Prediction:
    Unified Message Passing Model for Semi-Supervised Classification"
    <https://arxiv.org/abs/2009.03509>`_ paper.

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed via
    multi-head dot product attention:

    .. math::
        \alpha_{i,j} = \textrm{softmax} \left(
        \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)}
        {\sqrt{d}} \right)

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        beta (bool, optional): If set, will combine aggregation and
            skip information via

            .. math::
                \mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i +
                (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}
                \alpha_{i,j} \mathbf{W}_2 \vec{x}_j \right)}_{=\mathbf{m}_i}

            with :math:`\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top}
            [ \mathbf{W}_1 \mathbf{x}_i, \mathbf{m}_i, \mathbf{W}_1
            \mathbf{x}_i - \mathbf{m}_i ])` (default: :obj:`False`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        edge_dim (int, optional): Edge feature dimensionality (in case
            there are any). Edge features are added to the keys after
            linear transformation, that is, prior to computing the
            attention dot product. They are also added to final values
            after the same linear transformation. The model is:

            .. math::
                \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
                \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left(
                \mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij}
                \right),

            where the attention coefficients :math:`\alpha_{i,j}` are now
            computed via:

            .. math::
                \alpha_{i,j} = \textrm{softmax} \left(
                \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top}
                (\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})}
                {\sqrt{d}} \right)

            (default :obj:`None`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add the transformed root node features to the output and the
            option  :attr:`beta` is set to :obj:`False`. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    _alpha: OptTensor

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        heads: int = 1,
        concat: bool = True,
        beta: bool = False,
        dropout: float = 0.,
        edge_dim: Optional[int] = None,
        bias: bool = True,
        root_weight: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.beta = beta and root_weight
        self.root_weight = root_weight
        self.concat = concat
        self.dropout = dropout
        self.edge_dim = edge_dim
        self._alpha = None

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        self.lin_key = Linear(in_channels[0], heads * out_channels)
        self.lin_query = Linear(in_channels[1], heads * out_channels)
        self.lin_value = Linear(in_channels[0], heads * out_channels)
        if edge_dim is not None:
            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)
        else:
            self.lin_edge = self.register_parameter('lin_edge', None)

        if concat:
            self.lin_skip = Linear(in_channels[1], heads * out_channels,
                                   bias=bias)
            if self.beta:
                self.lin_beta = Linear(3 * heads * out_channels, 1, bias=False)
            else:
                self.lin_beta = self.register_parameter('lin_beta', None)
        else:
            self.lin_skip = Linear(in_channels[1], out_channels, bias=bias)
            if self.beta:
                self.lin_beta = Linear(3 * out_channels, 1, bias=False)
            else:
                self.lin_beta = self.register_parameter('lin_beta', None)

        self.reset_parameters()

[docs]    def reset_parameters(self):
        super().reset_parameters()
        self.lin_key.reset_parameters()
        self.lin_query.reset_parameters()
        self.lin_value.reset_parameters()
        if self.edge_dim:
            self.lin_edge.reset_parameters()
        self.lin_skip.reset_parameters()
        if self.beta:
            self.lin_beta.reset_parameters()

    @overload
    def forward(
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        return_attention_weights: NoneType = None,
    ) -> Tensor:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Tensor,
        edge_attr: OptTensor = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: SparseTensor,
        edge_attr: OptTensor = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, SparseTensor]:
        pass

[docs]    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        return_attention_weights: Optional[bool] = None,
    ) -> Union[
            Tensor,
            Tuple[Tensor, Tuple[Tensor, Tensor]],
            Tuple[Tensor, SparseTensor],
    ]:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input node
                features.
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            edge_attr (torch.Tensor, optional): The edge features.
                (default: :obj:`None`)
            return_attention_weights (bool, optional): If set to :obj:`True`,
                will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. (default: :obj:`None`)
        """
        H, C = self.heads, self.out_channels

        if isinstance(x, Tensor):
            x = (x, x)

        query = self.lin_query(x[1]).view(-1, H, C)
        key = self.lin_key(x[0]).view(-1, H, C)
        value = self.lin_value(x[0]).view(-1, H, C)

        # propagate_type: (query: Tensor, key:Tensor, value: Tensor,
        #                  edge_attr: OptTensor)
        out = self.propagate(edge_index, query=query, key=key, value=value,
                             edge_attr=edge_attr)

        alpha = self._alpha
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.root_weight:
            x_r = self.lin_skip(x[1])
            if self.lin_beta is not None:
                beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1))
                beta = beta.sigmoid()
                out = beta * x_r + (1 - beta) * out
            else:
                out = out + x_r

        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,
                edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:

        if self.lin_edge is not None:
            assert edge_attr is not None
            edge_attr = self.lin_edge(edge_attr).view(-1, self.heads,
                                                      self.out_channels)
            key_j = key_j + edge_attr

        alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels)
        alpha = softmax(alpha, index, ptr, size_i)
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        out = value_j
        if edge_attr is not None:
            out = out + edge_attr

        out = out * alpha.view(-1, self.heads, 1)
        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')

2.详细解释一下

几个重要的参数

in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

out_channels (int): Size of each output sample.

heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`)

怎么理解这几个参数?感觉之前自己学习看东西不细,今天又来重复学习了。还是得看英文原版,学习是一点不能走捷径。

class TransformerConv(in_channels:Union[int,Tuple[int,int]],out_channels:int,heads:int = 1,concat:boll=True,beta:bool=False,dropout:float=0.0,edge_dim:Optimal[int]=None,bias:bool=True,root_weight:bool=True,**kwargs)

  • in_channels 表示每个输入样本的大小。如果设置为整数,则表示所有输入样本的大小相同;如果设置为 -1,则表示输入样本的大小将从 forward 方法的第一个输入中推导出来;如果设置为元组,则表示输入样本的大小对应于源维度和目标维度的大小。

  • 其中in_channels:Union[int,Tuple[int,int]]。就是说in_channels后面的参数是可以是int或者是元组的形式。

  • 整数 (int):当 in_channels 是单个整数时,它表示所有输入节点特征的维度相同。例如,如果 in_channels=16,则每个节点的输入特征维度都是 16。

  • 整数元组 (Tuple[int, int]):当 in_channels 是一个整数元组时,它表示源节点和目标节点的输入特征维度分别不同。例如,in_channels=(16, 32) 表示源节点的输入特征维度是 16,而目标节点的输入特征维度是 32。这种情况通常用于异构图或双向图神经网络中。

    import torch
    from torch_geometric.nn import TransformerConv
    
    # 节点特征矩阵
    x = torch.randn((10, 16))  # 10 个节点,每个节点 16 维特征
    
    # 边索引矩阵
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]], dtype=torch.long)
    
    # 创建 TransformerConv 层(单一输入维度)
    conv = TransformerConv(in_channels=16, out_channels=32, heads=1, dropout=0.1)
    
    # 应用 TransformerConv 层
    out = conv(x, edge_index)
    print(out.shape)  # 输出特征矩阵,形状为 (10, 32)
    

    这个 forward 方法是图神经网络中用来进行前向传播(forward pass)的关键方法。它接收输入特征和图的结构信息,并计算输出特征。

    参数解释

  • x (Union[Tensor, PairTensor])

    • 输入节点的特征。可以是单个张量,也可以是一个包含两个张量的元组(PairTensor)。
    • 如果是单个张量,形状通常为 [num_nodes,in_channels][num\_nodes, in\_channels][num_nodes,in_channels],表示节点数量和每个节点的特征维度。
    • 如果是 PairTensor,意味着我们可能处理的是异构图,源节点和目标节点的特征可以不同。
  • edge_index (Adj)

    • 图的边索引,表示节点之间的连接关系。通常是形状为 [2,num_edges][2, num\_edges][2,num_edges] 的张量,表示每条边的起点和终点。
    • Adj 类型可以是稀疏矩阵(SparseTensor)或者边索引矩阵(edge_index)。
  • edge_attr (OptTensor)

    • 边的特征,可选参数。如果存在,则形状通常为 [num_edges,edge_feature_dim][num\_edges, edge\_feature\_dim][num_edges,edge_feature_dim],表示每条边的特征。
  • 返回值

    返回值的类型是一个联合类型(Union),具体取决于 return_attention_weights 的值:

  • Tensor

    • 如果 return_attention_weightsFalse,则返回更新后的节点特征。形状通常为 [num_nodes,out_channels][num\_nodes, out\_channels][num_nodes,out_channels]。
  • Tuple[Tensor, Tuple[Tensor, Tensor]]

    • 如果 return_attention_weightsTrue 并且使用的是普通的边索引矩阵,则返回一个包含两个元素的元组:
      • 第一个元素是更新后的节点特征。
      • 第二个元素是包含两个张量的元组,表示注意力权重及其对应的边索引。
  • Tuple[Tensor, SparseTensor]

    • 如果 return_attention_weightsTrue 并且使用的是稀疏矩阵,则返回一个包含两个元素的元组:
      • 第一个元素是更新后的节点特征。
      • 第二个元素是稀疏矩阵形式的注意力权重。
  • return_attention_weights (Optional[bool])

    • 是否返回注意力权重(attention weights)。如果设置为 True,前向传播会返回注意力权重。

out_channels 表示每个输出样本的大小,即经过卷积操作后产生的特征向量的维度大小。

当使用 tg.nn.TransformerConv 时,可以通过以下方式理解 in_channelsout_channels

假设我们有一个图数据集,每个节点都有一个 10 维的特征向量表示。那么在这种情况下:

  • 如果我们想将每个节点的特征向量作为输入,然后使用 tg.nn.TransformerConv 进行卷积操作,那么 in_channels 应该设置为 10,表示每个输入样本的大小为 10。

  • 假设我们想将节点的特征向量转换为一个 16 维的特征向量,那么 out_channels 应该设置为 16,表示每个输出样本的大小为 16,即经过卷积操作后每个节点的特征向量将变为 16 维。

  • tg.nn.TransformerConv 中,heads 参数表示多头注意力的数量。举个例子,如果 heads 参数设置为 4,那么模型将学习 4 组注意力权重,每组权重都用于计算输入的不同子空间的注意力,然后将这些头的输出进行合并以产生最终的输出。

 举个整体的例子

我们有一个输入张量 x,它的形状是 (batch_size, seq_length, input_dim),其中:

  • batch_size 表示批量大小;
  • seq_length 表示序列长度;
  • input_dim 表示输入特征的维度。

现在假设我们使用了 tg.nn.TransformerConv,并设置 heads=2,那么模型将学习两组注意力权重,每组用于计算不同的注意力。输出张量的形状将取决于 out_channels 参数,我们假设 out_channels=64

import torch
import torch_geometric.nn as tg

# 假设输入张量的形状是 (batch_size, seq_length, input_dim)
x = torch.randn(32, 10, 128)  # 32 个样本,每个样本有 10 个时间步,每个时间步有 128 个特征

# 创建 TransformerConv 模型,设置 heads=2,out_channels=64
conv_layer = tg.nn.TransformerConv(in_channels=128, out_channels=64, heads=2)

# 使用模型进行前向传播
output = conv_layer(x)

print("输出张量的形状:", output.shape)

 2.1将特征映射到键值对中

在这里,通过线性变换层 Linear,输入特征被转换成了键(key)、查询(query)和数值(value)的表示形式,以便用于多头自注意力机制。

具体来说:

  • self.lin_key 用于将输入特征(in_channels[0])映射到键的表示形式。
  • self.lin_query 用于将输入特征(in_channels[1])映射到查询的表示形式。
  • self.lin_value 用于将输入特征(in_channels[0])映射到数值的表示形式。

 具体地,假设输入特征的维度是 (batch_size, num_nodes, in_channels),其中 batch_size 是批量大小,num_nodes 是节点数,in_channels 是输入特征的通道数。在映射到键的过程中,线性变换层的权重矩阵将是一个维度为 (in_channels, heads * out_channels) 的矩阵,其中 heads 是注意力头的数量,out_channels 是输出特征的通道数。因此,通过矩阵乘法运算,输入特征将被映射到一个新的特征空间,其维度为 (batch_size, num_nodes, heads, out_channels)。在这个新的特征空间中,每个节点的每个头都有一个键表示。

3.里面的数学表达式意义

具体的卷积操作

  1. 特征变换: 首先,对输入特征进行线性变换。这一步的作用是将输入的节点特征映射到一个新的特征空间。

  2. 计算注意力权重: 使用注意力机制来计算节点之间的注意力权重。这个过程包括以下步骤:

    • 对每个节点的特征向量进行变换,得到查询向量(query)、键向量(key)和值向量(value)。
    • 计算查询向量和键向量的点积,并除以缩放因子(通常是键向量维度的平方根),然后应用 softmax 函数得到注意力权重。
  3. 消息传递: 根据计算得到的注意力权重,将邻居节点的特征加权求和,形成新的节点表示。

  4. 特征更新: 将消息传递得到的新节点表示进行线性变换,并根据需要进行拼接或求平均操作。

  5. 具体的这个W是怎么学到的呢?

  6. 输入的特征维度不一定和输出的特征维度一样。该步骤是作为一个特征提取的过程

  7. 我们看一下论文中一般怎么定义这个过程吧!为了聚合节点信息

  • 43
    点赞
  • 43
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
torch_geometric.loader.DataLoaderPyG中的一个类,用于加载和处理图数据。它可以将多个图批处理成单个巨型图,并提供了一些方便的功能。\[2\] 您可以使用torch_geometric.loader.DataLoader来加载和处理图数据集。例如,您可以创建一个包含torch_geometric.data.Data对象的常规Python列表,并将其传递给DataLoader来批处理这些图数据。\[1\] DataLoader还可以接受一些参数,例如batch_size和shuffle,以控制批处理的大小和数据的顺序。您还可以使用其他可以传递给PyTorch DataLoader的参数,例如num_workers。\[2\] 使用DataLoader加载图数据集的示例代码如下:\[3\] ```python from torch_geometric.datasets import TUDataset from torch_geometric.loader import DataLoader dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True) loader = DataLoader(dataset, batch_size=32, shuffle=True) for batch in loader: # 在这里对批处理的图数据进行处理 # 例如,计算每个图的节点维度中的平均节点特征 x = scatter_mean(batch.x, batch.batch, dim=0) print(x.size()) # 输出每个图的节点特征的大小 ``` 通过使用torch_geometric.loader.DataLoader,您可以方便地加载和处理图数据集。它提供了一种简单而有效的方式来处理大规模的图数据。\[3\] #### 引用[.reference_title] - *1* *3* [【PyG】文档总结以及项目经验(持续更新](https://blog.csdn.net/weixin_45928096/article/details/125501673)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^koosearch_v1,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* [第十九课.Pytorch-geometric扩展](https://blog.csdn.net/qq_40943760/article/details/120265255)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^koosearch_v1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值