【代码解读】torch_geometric.nn.RGCNConv

定义 RGCNConv 类

class RGCNConv(MessagePassing):
    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        num_relations: int,
        num_bases: Optional[int] = None,
        num_blocks: Optional[int] = None,
        aggr: str = 'mean',
        root_weight: bool = True,
        is_sorted: bool = False,
        bias: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', aggr)
        super().__init__(node_dim=0, **kwargs)
  • 初始化函数定义了图卷积的输入参数和基本设置。
  • in_channels 和 out_channels 分别表示输入和输出特征的维度。
  • num_relations 表示关系的数量。
  • num_bases 和 num_blocks 可选参数用于正则化方案。
  • aggr 表示聚合方法(默认值为 ‘mean’)。
  • root_weight 和 bias 分别表示是否使用根节点权重和偏置。

参数设置和初始化

        if num_bases is not None and num_blocks is not None:
            raise ValueError('Can not apply both basis-decomposition and '
                             'block-diagonal-decomposition at the same time.')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.num_bases = num_bases
        self.num_blocks = num_blocks
        self.is_sorted = is_sorted

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)
        self.in_channels_l = in_channels[0]

        self._use_segment_matmul_heuristic_output: Optional[bool] = None

        if num_bases is not None:
            self.weight = Parameter(
                torch.empty(num_bases, in_channels[0], out_channels))
            self.comp = Parameter(torch.empty(num_relations, num_bases))

        elif num_blocks is not None:
            assert (in_channels[0] % num_blocks == 0
                    and out_channels % num_blocks == 0)
            self.weight = Parameter(
                torch.empty(num_relations, num_blocks,
                            in_channels[0] // num_blocks,
                            out_channels // num_blocks))
            self.register_parameter('comp', None)

        else:
            self.weight = Parameter(
                torch.empty(num_relations, in_channels[0], out_channels))
            self.register_parameter('comp', None)

        if root_weight:
            self.root = Parameter(torch.empty(in_channels[1], out_channels))
        else:
            self.register_parameter('root', None)

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()
  • 处理 num_bases 和 num_blocks 参数冲突。

  • 初始化输入和输出特征维度、关系数量、正则化参数等。

  • 根据是否使用 num_bases 或 num_blocks 进行权重初始化(若不需要使用 num_bases 或 num_blocks进行权重初始化,就使用常规的(num_relations, in_channels[0], out_channels)即可)。

  • 根据 root_weight 和 bias 初始化根节点权重和偏置。

    root_weight 参数用于决定是否在节点特征的更新过程中引入自身特征的权重。(如果 root_weight 设置为 True,则在更新节点特征时会将自身特征也考虑进去。)

    bias 参数用于决定是否在节点特征的更新过程中添加偏置项。(如果 bias 设置为 True,则在更新节点特征时会添加一个偏置项。)

参数重置函数

def reset_parameters(self):
    super().reset_parameters()
    glorot(self.weight)
    glorot(self.comp)
    glorot(self.root)
    zeros(self.bias)

在这里插入图片描述

import torch
from torch import Tensor

def glorot(tensor: Tensor) -> Tensor:
    if tensor is not None:
        torch.nn.init.xavier_uniform_(tensor)
    return tensor

forward函数定义

处理输入特征

def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],
            edge_index: Adj, edge_type: OptTensor = None):   
            
	x_l: OptTensor = None
	if isinstance(x, tuple):
		x_l = x[0]
	else:
		x_l = x
	if x_l is None:
	    x_l = torch.arange(self.in_channels_l, device=self.weight.device)

	x_r: Tensor = x_l
	if isinstance(x, tuple):
	    x_r = x[1]
  • x_l 表示源节点特征,如果 x 是元组,则取第一个元素,否则直接赋值为 x。
  • 如果 x_l 为 None,则生成一个范围在 0 到 self.in_channels_l - 1 的张量。
  • x_r 表示目标节点特征,如果 x 是元组,则取第二个元素,否则直接赋值为 x_l。

处理边索引和边类型

size = (x_l.size(0), x_r.size(0))
if isinstance(edge_index, SparseTensor):
    edge_type = edge_index.storage.value()
assert edge_type is not None
  • size 是一个元组,包含源节点和目标节点的数量。
  • 如果 edge_index 是一个稀疏张量,则从中提取边类型。
  • 确保 edge_type 不是 None。

初始化输出张量

out = torch.zeros(x_r.size(0), self.out_channels, device=x_r.device)
  • 初始化输出张量 out,形状为 [目标节点数, 输出特征维度],所有元素初始化为零。

处理权重矩阵

weight = self.weight
if self.num_bases is not None:  # Basis-decomposition
    weight = (self.comp @ weight.view(self.num_bases, -1)).view(
        self.num_relations, self.in_channels_l, self.out_channels)

处理块对角分解

if self.num_blocks is not None:  # Block-diagonal-decomposition
    if not torch.is_floating_point(x_r) and self.num_blocks is not None:
        raise ValueError('Block-diagonal decomposition not supported '
                         'for non-continuous input features.')

    for i in range(self.num_relations):
        tmp = masked_edge_index(edge_index, edge_type == i)
        h = self.propagate(tmp, x=x_l, edge_type_ptr=None, size=size)
        h = h.view(-1, weight.size(1), weight.size(2))
        h = torch.einsum('abc,bcd->abd', h, weight[i])
        out = out + h.contiguous().view(-1, self.out_channels)
  • 如果使用块对角分解,则对每种关系类型分别处理。
  • 根据边类型掩码获取相应的边索引。
  • 调用 propagate 方法进行消息传递。
  • 使用 torch.einsum 进行块对角矩阵的乘法运算。
  • 将结果累加到输出张量中。

初始化并确定是否使用 segment_matmul

else:  # No regularization/Basis-decomposition
    use_segment_matmul = torch_geometric.backend.use_segment_matmul
    if use_segment_matmul is None:
        segment_count = scatter(torch.ones_like(edge_type), edge_type,
                                dim_size=self.num_relations)
        self._use_segment_matmul_heuristic_output = (
            torch_geometric.backend.use_segment_matmul_heuristic(
                num_segments=self.num_relations,
                max_segment_size=int(segment_count.max()),
                in_channels=self.weight.size(1),
                out_channels=self.weight.size(2),
            ))
        assert self._use_segment_matmul_heuristic_output is not None
        use_segment_matmul = self._use_segment_matmul_heuristic_output
  • 如果 use_segment_matmul 为空,则我们根据输入的边类型数量和特征维度来决定是否使用 segment_matmul。
  • scatter 用于计算每种关系类型的数量。——将长度为edge_type的全一张量根据edge_type来分散,得到的segement_count为各种edge_type的总数。
segment_count:tensor([1116655, 2626979], device='cuda:0')
edge_type.shape:torch.Size([3743634])
  • torch_geometric.backend.use_segment_matmul_heuristic 是一个启发式方法,根据 num_segments(关系类型数)、max_segment_size(最大关系类型数)、输入和输出特征维度来决定是否使用 segment_matmul。
  • 结果保存在 self._use_segment_matmul_heuristic_output 中,并赋值给 use_segment_matmul。

处理排序和调用 segment_matmul

if (use_segment_matmul and torch_geometric.typing.WITH_SEGMM
        and not is_compiling() and self.num_bases is None
        and x_l.is_floating_point()
        and isinstance(edge_index, Tensor)):
    if not self.is_sorted:
        if (edge_type[1:] < edge_type[:-1]).any():
            edge_type, perm = index_sort(
                edge_type, max_value=self.num_relations)
            edge_index = edge_index[:, perm]
    edge_type_ptr = index2ptr(edge_type, self.num_relations)
    out = self.propagate(edge_index, x=x_l,
                         edge_type_ptr=edge_type_ptr, size=size)
  • 检查是否满足使用 segment_matmul 的条件:

    use_segment_matmul 为真。
    torch_geometric.typing.WITH_SEGMM 为真。
    当前不是在编译过程中。
    没有使用基分解(self.num_bases 为 None)。
    x_l 是浮点数。
    edge_index 是张量。

  • 如果满足上述条件并且边类型没有排序,则对边类型进行排序,并相应地调整边索引。
    使用edge_type[1:] < edge_type[:-1]判断是否edge_type中的前一个元素均比后一个元素大,从而可得出是否升序排序。那么相同类型的边(关系)将被分段处理,从而最大限度地利用 segment_matmul 的优势,提高乘法操作的效率。
    index_sort用于对张量进行排序,并返回排序后的张量及其对应的排序索引。max_value: 可选参数,用于指定张量中元素的最大值。如果提供了这个值,函数可以使用更高效的排序算法。

  • 使用 index2ptr 将边类型转换为指针。

在这里插入图片描述

edge_type_ptr = index2ptr(edge_type, self.num_relations)

在这里插入图片描述
在这里插入图片描述

  • 调用 propagate 方法进行消息传递,传递参数包括边索引、节点特征和边类型指针。
    在这里插入图片描述

不使用segment_matmul

else:
    for i in range(self.num_relations):
        tmp = masked_edge_index(edge_index, edge_type == i)
        if not torch.is_floating_point(x_r):
            out = out + self.propagate(
                tmp,
                x=weight[i, x_l],
                edge_type_ptr=None,
                size=size,
            )
        else:
            h = self.propagate(tmp, x=x_l, edge_type_ptr=None,
                               size=size)
            out = out + (h @ weight[i])
  • 如果不使用 segment_matmul,则对每种关系类型分别处理。
  • masked_edge_index 根据当前关系类型 i 生成对应的边索引。
    在这里插入图片描述
if not torch.is_floating_point(x_r):
    out = out + root[x_r]

在这里插入图片描述

else:
    out = out + x_r @ root

在这里插入图片描述

message函数

def message(self, x_j: Tensor, edge_type_ptr: OptTensor) -> Tensor:
    if (torch_geometric.typing.WITH_SEGMM and not is_compiling()
            and edge_type_ptr is not None):
        # TODO Re-weight according to edge type degree for `aggr=mean`.
        return pyg_lib.ops.segment_matmul(x_j, edge_type_ptr, self.weight)

    return x_j
  • 检查条件:
    torch_geometric.typing.WITH_SEGMM:检查是否支持 segment_matmul 操作。
    not is_compiling():检查当前是否处于编译模式。
    edge_type_ptr is not None:确保 edge_type_ptr 不为空。
    只有在这三个条件都满足的情况下,才会使用 segment_matmul 方法。
  • 执行 segment_matmul 操作
    pyg_lib.ops.segment_matmul 是一种高效的矩阵乘法操作。
    x_j 是源节点特征。
    edge_type_ptr 是边类型指针,用于指示每种边类型的起始位置。
    self.weight 是每种边类型的权重矩阵。
    segment_matmul 方法将根据边类型指针将 x_j 分块,然后与相应的权重矩阵进行矩阵乘法运算,并返回结果。
def segment_matmul(x: Tensor, ptr: Tensor, weight: Tensor) -> Tensor:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

这里还有一个 TODO,表示需要根据边类型的度数重新加权以进行 aggr=mean 聚合。这一步还未实现。
  • 默认返回源节点特征:
    如果不满足上述条件,则直接返回源节点特征 x_j。这是默认情况,不进行任何操作。

message_and_aggregate

def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:
    if isinstance(adj_t, SparseTensor):
        adj_t = adj_t.set_value(None)
    return spmm(adj_t, x, reduce=self.aggr)
  • 参数:

    adj_t: Adj:邻接矩阵(或张量),表示图的边连接信息。类型可以是稀疏张量 (SparseTensor) 或一般张量。
    x: Tensor:节点特征张量,维度为 [num_nodes, in_channels],即每个节点的输入特征。

  • 检查并处理邻接矩阵类型
    if isinstance(adj_t, SparseTensor):检查 adj_t 是否是稀疏张量。
    adj_t.set_value(None):如果 adj_t 是稀疏张量,则将其值部分设置为 None。这表示我们只关心稀疏张量的结构部分,而不关心它的值部分。这通常用于表示无权图或将权重视为1的情况。

  • 执行稀疏矩阵-稠密矩阵乘法(spmm)——将稀疏邻接矩阵 adj_t 和节点特征矩阵 x 进行乘法运算,得到新的节点特征矩阵
    adj_t:邻接矩阵,表示图的结构。——用于将节点的邻居特征聚合到节点自身
    x:节点特征矩阵。
    reduce=self.aggr:聚合方式,可以是 ‘add’(加法聚合)、‘mean’(平均聚合)或 ‘max’(最大值聚合)。这决定了在聚合消息时使用的具体方法。

  • 27
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值