定义 RGCNConv 类
class RGCNConv(MessagePassing):
def __init__(
self,
in_channels: Union[int, Tuple[int, int]],
out_channels: int,
num_relations: int,
num_bases: Optional[int] = None,
num_blocks: Optional[int] = None,
aggr: str = 'mean',
root_weight: bool = True,
is_sorted: bool = False,
bias: bool = True,
**kwargs,
):
kwargs.setdefault('aggr', aggr)
super().__init__(node_dim=0, **kwargs)
- 初始化函数定义了图卷积的输入参数和基本设置。
- in_channels 和 out_channels 分别表示输入和输出特征的维度。
- num_relations 表示关系的数量。
- num_bases 和 num_blocks 可选参数用于正则化方案。
- aggr 表示聚合方法(默认值为 ‘mean’)。
- root_weight 和 bias 分别表示是否使用根节点权重和偏置。
参数设置和初始化
if num_bases is not None and num_blocks is not None:
raise ValueError('Can not apply both basis-decomposition and '
'block-diagonal-decomposition at the same time.')
self.in_channels = in_channels
self.out_channels = out_channels
self.num_relations = num_relations
self.num_bases = num_bases
self.num_blocks = num_blocks
self.is_sorted = is_sorted
if isinstance(in_channels, int):
in_channels = (in_channels, in_channels)
self.in_channels_l = in_channels[0]
self._use_segment_matmul_heuristic_output: Optional[bool] = None
if num_bases is not None:
self.weight = Parameter(
torch.empty(num_bases, in_channels[0], out_channels))
self.comp = Parameter(torch.empty(num_relations, num_bases))
elif num_blocks is not None:
assert (in_channels[0] % num_blocks == 0
and out_channels % num_blocks == 0)
self.weight = Parameter(
torch.empty(num_relations, num_blocks,
in_channels[0] // num_blocks,
out_channels // num_blocks))
self.register_parameter('comp', None)
else:
self.weight = Parameter(
torch.empty(num_relations, in_channels[0], out_channels))
self.register_parameter('comp', None)
if root_weight:
self.root = Parameter(torch.empty(in_channels[1], out_channels))
else:
self.register_parameter('root', None)
if bias:
self.bias = Parameter(torch.empty(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
-
处理 num_bases 和 num_blocks 参数冲突。
-
初始化输入和输出特征维度、关系数量、正则化参数等。
-
根据是否使用 num_bases 或 num_blocks 进行权重初始化(若不需要使用 num_bases 或 num_blocks进行权重初始化,就使用常规的(num_relations, in_channels[0], out_channels)即可)。
-
根据 root_weight 和 bias 初始化根节点权重和偏置。
root_weight 参数用于决定是否在节点特征的更新过程中引入自身特征的权重。(如果 root_weight 设置为 True,则在更新节点特征时会将自身特征也考虑进去。)
bias 参数用于决定是否在节点特征的更新过程中添加偏置项。(如果 bias 设置为 True,则在更新节点特征时会添加一个偏置项。)
参数重置函数
def reset_parameters(self):
super().reset_parameters()
glorot(self.weight)
glorot(self.comp)
glorot(self.root)
zeros(self.bias)
import torch
from torch import Tensor
def glorot(tensor: Tensor) -> Tensor:
if tensor is not None:
torch.nn.init.xavier_uniform_(tensor)
return tensor
forward函数定义
处理输入特征
def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],
edge_index: Adj, edge_type: OptTensor = None):
x_l: OptTensor = None
if isinstance(x, tuple):
x_l = x[0]
else:
x_l = x
if x_l is None:
x_l = torch.arange(self.in_channels_l, device=self.weight.device)
x_r: Tensor = x_l
if isinstance(x, tuple):
x_r = x[1]
- x_l 表示源节点特征,如果 x 是元组,则取第一个元素,否则直接赋值为 x。
- 如果 x_l 为 None,则生成一个范围在 0 到 self.in_channels_l - 1 的张量。
- x_r 表示目标节点特征,如果 x 是元组,则取第二个元素,否则直接赋值为 x_l。
处理边索引和边类型
size = (x_l.size(0), x_r.size(0))
if isinstance(edge_index, SparseTensor):
edge_type = edge_index.storage.value()
assert edge_type is not None
- size 是一个元组,包含源节点和目标节点的数量。
- 如果 edge_index 是一个稀疏张量,则从中提取边类型。
- 确保 edge_type 不是 None。
初始化输出张量
out = torch.zeros(x_r.size(0), self.out_channels, device=x_r.device)
- 初始化输出张量 out,形状为 [目标节点数, 输出特征维度],所有元素初始化为零。
处理权重矩阵
weight = self.weight
if self.num_bases is not None: # Basis-decomposition
weight = (self.comp @ weight.view(self.num_bases, -1)).view(
self.num_relations, self.in_channels_l, self.out_channels)
处理块对角分解
if self.num_blocks is not None: # Block-diagonal-decomposition
if not torch.is_floating_point(x_r) and self.num_blocks is not None:
raise ValueError('Block-diagonal decomposition not supported '
'for non-continuous input features.')
for i in range(self.num_relations):
tmp = masked_edge_index(edge_index, edge_type == i)
h = self.propagate(tmp, x=x_l, edge_type_ptr=None, size=size)
h = h.view(-1, weight.size(1), weight.size(2))
h = torch.einsum('abc,bcd->abd', h, weight[i])
out = out + h.contiguous().view(-1, self.out_channels)
- 如果使用块对角分解,则对每种关系类型分别处理。
- 根据边类型掩码获取相应的边索引。
- 调用 propagate 方法进行消息传递。
- 使用 torch.einsum 进行块对角矩阵的乘法运算。
- 将结果累加到输出张量中。
初始化并确定是否使用 segment_matmul
else: # No regularization/Basis-decomposition
use_segment_matmul = torch_geometric.backend.use_segment_matmul
if use_segment_matmul is None:
segment_count = scatter(torch.ones_like(edge_type), edge_type,
dim_size=self.num_relations)
self._use_segment_matmul_heuristic_output = (
torch_geometric.backend.use_segment_matmul_heuristic(
num_segments=self.num_relations,
max_segment_size=int(segment_count.max()),
in_channels=self.weight.size(1),
out_channels=self.weight.size(2),
))
assert self._use_segment_matmul_heuristic_output is not None
use_segment_matmul = self._use_segment_matmul_heuristic_output
- 如果 use_segment_matmul 为空,则我们根据输入的边类型数量和特征维度来决定是否使用 segment_matmul。
- scatter 用于计算每种关系类型的数量。——将长度为edge_type的全一张量根据edge_type来分散,得到的segement_count为各种edge_type的总数。
segment_count:tensor([1116655, 2626979], device='cuda:0')
edge_type.shape:torch.Size([3743634])
- torch_geometric.backend.use_segment_matmul_heuristic 是一个启发式方法,根据 num_segments(关系类型数)、max_segment_size(最大关系类型数)、输入和输出特征维度来决定是否使用 segment_matmul。
- 结果保存在 self._use_segment_matmul_heuristic_output 中,并赋值给 use_segment_matmul。
处理排序和调用 segment_matmul
if (use_segment_matmul and torch_geometric.typing.WITH_SEGMM
and not is_compiling() and self.num_bases is None
and x_l.is_floating_point()
and isinstance(edge_index, Tensor)):
if not self.is_sorted:
if (edge_type[1:] < edge_type[:-1]).any():
edge_type, perm = index_sort(
edge_type, max_value=self.num_relations)
edge_index = edge_index[:, perm]
edge_type_ptr = index2ptr(edge_type, self.num_relations)
out = self.propagate(edge_index, x=x_l,
edge_type_ptr=edge_type_ptr, size=size)
-
检查是否满足使用 segment_matmul 的条件:
use_segment_matmul 为真。
torch_geometric.typing.WITH_SEGMM 为真。
当前不是在编译过程中。
没有使用基分解(self.num_bases 为 None)。
x_l 是浮点数。
edge_index 是张量。 -
如果满足上述条件并且边类型没有排序,则对边类型进行排序,并相应地调整边索引。
使用edge_type[1:] < edge_type[:-1]判断是否edge_type中的前一个元素均比后一个元素大,从而可得出是否升序排序。那么相同类型的边(关系)将被分段处理,从而最大限度地利用 segment_matmul 的优势,提高乘法操作的效率。
index_sort用于对张量进行排序,并返回排序后的张量及其对应的排序索引。max_value: 可选参数,用于指定张量中元素的最大值。如果提供了这个值,函数可以使用更高效的排序算法。 -
使用 index2ptr 将边类型转换为指针。
edge_type_ptr = index2ptr(edge_type, self.num_relations)
- 调用 propagate 方法进行消息传递,传递参数包括边索引、节点特征和边类型指针。
不使用segment_matmul
else:
for i in range(self.num_relations):
tmp = masked_edge_index(edge_index, edge_type == i)
if not torch.is_floating_point(x_r):
out = out + self.propagate(
tmp,
x=weight[i, x_l],
edge_type_ptr=None,
size=size,
)
else:
h = self.propagate(tmp, x=x_l, edge_type_ptr=None,
size=size)
out = out + (h @ weight[i])
- 如果不使用 segment_matmul,则对每种关系类型分别处理。
- masked_edge_index 根据当前关系类型 i 生成对应的边索引。
if not torch.is_floating_point(x_r):
out = out + root[x_r]
else:
out = out + x_r @ root
message函数
def message(self, x_j: Tensor, edge_type_ptr: OptTensor) -> Tensor:
if (torch_geometric.typing.WITH_SEGMM and not is_compiling()
and edge_type_ptr is not None):
# TODO Re-weight according to edge type degree for `aggr=mean`.
return pyg_lib.ops.segment_matmul(x_j, edge_type_ptr, self.weight)
return x_j
- 检查条件:
torch_geometric.typing.WITH_SEGMM:检查是否支持 segment_matmul 操作。
not is_compiling():检查当前是否处于编译模式。
edge_type_ptr is not None:确保 edge_type_ptr 不为空。
只有在这三个条件都满足的情况下,才会使用 segment_matmul 方法。 - 执行 segment_matmul 操作
pyg_lib.ops.segment_matmul 是一种高效的矩阵乘法操作。
x_j 是源节点特征。
edge_type_ptr 是边类型指针,用于指示每种边类型的起始位置。
self.weight 是每种边类型的权重矩阵。
segment_matmul 方法将根据边类型指针将 x_j 分块,然后与相应的权重矩阵进行矩阵乘法运算,并返回结果。
def segment_matmul(x: Tensor, ptr: Tensor, weight: Tensor) -> Tensor:
这里还有一个 TODO,表示需要根据边类型的度数重新加权以进行 aggr=mean 聚合。这一步还未实现。
- 默认返回源节点特征:
如果不满足上述条件,则直接返回源节点特征 x_j。这是默认情况,不进行任何操作。
message_and_aggregate
def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:
if isinstance(adj_t, SparseTensor):
adj_t = adj_t.set_value(None)
return spmm(adj_t, x, reduce=self.aggr)
-
参数:
adj_t: Adj:邻接矩阵(或张量),表示图的边连接信息。类型可以是稀疏张量 (SparseTensor) 或一般张量。
x: Tensor:节点特征张量,维度为 [num_nodes, in_channels],即每个节点的输入特征。 -
检查并处理邻接矩阵类型
if isinstance(adj_t, SparseTensor):检查 adj_t 是否是稀疏张量。
adj_t.set_value(None):如果 adj_t 是稀疏张量,则将其值部分设置为 None。这表示我们只关心稀疏张量的结构部分,而不关心它的值部分。这通常用于表示无权图或将权重视为1的情况。 -
执行稀疏矩阵-稠密矩阵乘法(spmm)——将稀疏邻接矩阵 adj_t 和节点特征矩阵 x 进行乘法运算,得到新的节点特征矩阵
adj_t:邻接矩阵,表示图的结构。——用于将节点的邻居特征聚合到节点自身
x:节点特征矩阵。
reduce=self.aggr:聚合方式,可以是 ‘add’(加法聚合)、‘mean’(平均聚合)或 ‘max’(最大值聚合)。这决定了在聚合消息时使用的具体方法。