【代码解读】torch_geometric.nn.GCNConv

在这里插入图片描述

  • A ^ \mathbf{\hat{A}} A^ 是添加了自环的邻接矩阵,元素 a i j a_{ij} aij可用 1 和 0 表示是否有连边,或等于 e i j e_{ij} eij 连边权值。
  • D ^ \mathbf{\hat{D}} D^ 是对角矩阵,对角元素 d i i d_{ii} dii 是当前节点的度。
  • D ^ − 1 / 2 A ^ D ^ − 1 / 2 \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}\mathbf{\hat{D}}^{-1/2} D^1/2A^D^1/2 中的元素 t i j t_{ij} tij e j , i d ^ i d ^ j \frac{e_{j,i}}{\sqrt{\hat{d}_i}\sqrt{\hat{d}_j}} d^i d^j ej,i。(即:message中确定了权重——边权除以source和target的度的1/2次幂的积)
    在这里插入图片描述

类定义和成员变量

class GCNConv(MessagePassing):
    _cached_edge_index: Optional[OptPairTensor]
    _cached_adj_t: Optional[SparseTensor]
  • _cached_edge_index 和 _cached_adj_t 是用于缓存边索引和邻接矩阵的变量,防止在每次前向传播时重复计算。

构造函数 init

def __init__(
    self,
    in_channels: int,
    out_channels: int,
    improved: bool = False,
    cached: bool = False,
    add_self_loops: Optional[bool] = None,
    normalize: bool = True,
    bias: bool = True,
    **kwargs,
):
    kwargs.setdefault('aggr', 'add')
    super().__init__(**kwargs)

    if add_self_loops is None:
        add_self_loops = normalize

    if add_self_loops and not normalize:
        raise ValueError(f"'{self.__class__.__name__}' does not support "
                         f"adding self-loops to the graph when no "
                         f"on-the-fly normalization is applied")

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.improved = improved
    self.cached = cached
    self.add_self_loops = add_self_loops
    self.normalize = normalize

    self._cached_edge_index = None
    self._cached_adj_t = None

    self.lin = Linear(in_channels, out_channels, bias=False,
                      weight_initializer='glorot')

    if bias:
        self.bias = Parameter(torch.empty(out_channels))
    else:
        self.register_parameter('bias', None)

    self.reset_parameters()
  • 检查是否在不归一化的情况下添加自环,如果是则抛出异常。
  • weight_initializer=‘glorot’ 指定使用 Xavier 初始化方法来初始化权重。

重置参数 reset_parameters

def reset_parameters(self):
    super().reset_parameters()
    self.lin.reset_parameters()
    zeros(self.bias)
    self._cached_edge_index = None
    self._cached_adj_t = None

前向传播 forward

def forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor:
    if isinstance(x, (tuple, list)):
        raise ValueError(f"'{self.__class__.__name__}' received a tuple "
                         f"of node features as input while this layer "
                         f"does not support bipartite message passing. "
                         f"Please try other layers such as 'SAGEConv' or "
                         f"'GraphConv' instead")

    if self.normalize:
        if isinstance(edge_index, Tensor):
            cache = self._cached_edge_index
            if cache is None:
                edge_index, edge_weight = gcn_norm(
                    edge_index, edge_weight, x.size(self.node_dim),
                    self.improved, self.add_self_loops, self.flow, x.dtype)
                if self.cached:
                    self._cached_edge_index = (edge_index, edge_weight)
            else:
                edge_index, edge_weight = cache[0], cache[1]

        elif isinstance(edge_index, SparseTensor):
            cache = self._cached_adj_t
            if cache is None:
                edge_index = gcn_norm(
                    edge_index, edge_weight, x.size(self.node_dim),
                    self.improved, self.add_self_loops, self.flow, x.dtype)
                if self.cached:
                    self._cached_adj_t = edge_index
            else:
                edge_index = cache

    x = self.lin(x)

    # propagate_type: (x: Tensor, edge_weight: OptTensor)
    out = self.propagate(edge_index, x=x, edge_weight=edge_weight)

    if self.bias is not None:
        out = out + self.bias

    return out

forward 方法定义了前向传播过程。

  1. 如果输入特征是元组或列表,则抛出异常,因为该层不支持二分图消息传递。
  2. 如果需要归一化,并且 edge_index 是张量,检查是否有缓存。如果没有缓存,计算归一化的边索引和边权重,并缓存结果。
  3. 如果 edge_index 是稀疏张量,同样处理缓存逻辑。
  4. 对节点特征进行线性变换。
  5. 调用 propagate 方法进行消息传递。
  6. 如果存在偏置,加上偏置。
  7. 返回输出特征。

消息函数 message

def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
    return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j
  1. message 方法定义了如何从源节点到目标节点传递消息。
  2. 如果没有边权重,直接返回源节点特征;如果有边权重,返回加权后的源节点特征。

消息和聚合 message_and_aggregate

def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:
    return spmm(adj_t, x, reduce=self.aggr)
  1. message_and_aggregate 方法用于稀疏矩阵的消息传递和聚合。
  2. 使用稀疏矩阵乘法 spmm 进行消息传递和聚合。

GCNConv类实现了GCN的核心思想

通过消息传递和特征聚合来更新节点的表示。它通过以下几个步骤实现:

  1. 添加自环,使得节点可以聚合自身特征。
  2. 对节点特征进行线性变换。
  3. 计算归一化因子,以保证特征的尺度一致。
  4. 聚合邻居节点的特征,并使用归一化因子进行归一化。
  5. 返回聚合后的节点特征。
    在这里插入图片描述

规范化方法 gcn_norm

重载 gcn_norm 函数的定义

@torch.jit._overload
def gcn_norm(  # noqa: F811
        edge_index, edge_weight, num_nodes, improved, add_self_loops, flow,
        dtype):
    # type: (Tensor, OptTensor, Optional[int], bool, bool, str, Optional[int]) -> OptPairTensor  # noqa
    pass

@torch.jit._overload
def gcn_norm(  # noqa: F811
        edge_index, edge_weight, num_nodes, improved, add_self_loops, flow,
        dtype):
    # type: (SparseTensor, OptTensor, Optional[int], bool, bool, str, Optional[int]) -> SparseTensor  # noqa
    pass
  • 通过 @torch.jit._overload 装饰器,定义了两个重载函数的签名,支持不同类型的 edge_index 输入

实际 gcn_norm 函数的实现

def gcn_norm(  # noqa: F811
    edge_index: Adj,
    edge_weight: OptTensor = None,
    num_nodes: Optional[int] = None,
    improved: bool = False,
    add_self_loops: bool = True,
    flow: str = "source_to_target",
    dtype: Optional[torch.dtype] = None,
):
    fill_value = 2. if improved else 1.

SparseTensor 类型的 edge_index和torch.sparse 类型的 edge_index:

  • 实现方式:

torch.sparse 类型是 PyTorch 的原生稀疏张量格式,适用于通用的稀疏矩阵操作。
torch_geometric 的 SparseTensor 是为图神经网络优化的稀疏矩阵格式,包含了许多图操作的优化。

  • 使用场景:

torch.sparse 适用于需要进行稀疏矩阵乘法等通用操作的场景。
torch_geometric 的 SparseTensor 适用于图神经网络,特别是需要处理大规模图数据的场景。

  • 功能支持:

torch.sparse 提供了基本的稀疏张量操作,如矩阵乘法、求和等。
torch_geometric 的 SparseTensor 提供了丰富的图操作支持,如自环添加、归一化、度数计算等。


处理 SparseTensor 类型的 edge_index

if isinstance(edge_index, SparseTensor):
    assert edge_index.size(0) == edge_index.size(1)

    adj_t = edge_index

    if not adj_t.has_value():
        adj_t = adj_t.fill_value(1., dtype=dtype)
    if add_self_loops:
        adj_t = torch_sparse.fill_diag(adj_t, fill_value)

    deg = torch_sparse.sum(adj_t, dim=1)
    deg_inv_sqrt = deg.pow_(-0.5)
    deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
    adj_t = torch_sparse.mul(adj_t, deg_inv_sqrt.view(-1, 1))
    adj_t = torch_sparse.mul(adj_t, deg_inv_sqrt.view(1, -1))

    return adj_t

如果 edge_index 是 SparseTensor 类型,执行以下步骤:

  1. 确保 edge_index 是方阵(行数等于列数)。
  2. 如果 adj_t 没有值,则填充为 1。
  3. 如果需要添加自环,则填充对角线。
  4. 计算度矩阵并求平方根的倒数,处理无穷值。
  5. 对 adj_t 进行归一化处理。

处理 torch.sparse 类型的 edge_index

if is_torch_sparse_tensor(edge_index):
    assert edge_index.size(0) == edge_index.size(1)

    if edge_index.layout == torch.sparse_csc:
        raise NotImplementedError("Sparse CSC matrices are not yet "
                                  "supported in 'gcn_norm'")

    adj_t = edge_index
    if add_self_loops:
        adj_t, _ = add_self_loops_fn(adj_t, None, fill_value, num_nodes)

    edge_index, value = to_edge_index(adj_t)
    col, row = edge_index[0], edge_index[1]

    deg = scatter(value, col, 0, dim_size=num_nodes, reduce='sum')
    deg_inv_sqrt = deg.pow_(-0.5)
    deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
    value = deg_inv_sqrt[row] * value * deg_inv_sqrt[col]

    return set_sparse_value(adj_t, value), None

如果 edge_index 是 PyTorch 稀疏张量,执行以下步骤:

  1. 确保 edge_index 是方阵。
  2. 如果 edge_index 使用稀疏列压缩格式(CSC),则抛出未实现错误。
  3. 如果需要添加自环,则添加。
  4. 将稀疏张量转换为边索引格式。
  5. 计算度矩阵并进行归一化处理。

处理普通张量的 edge_index

assert flow in ['source_to_target', 'target_to_source']
num_nodes = maybe_num_nodes(edge_index, num_nodes)

if add_self_loops:
    edge_index, edge_weight = add_remaining_self_loops(
        edge_index, edge_weight, fill_value, num_nodes)

if edge_weight is None:
    edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                             device=edge_index.device)

row, col = edge_index[0], edge_index[1]
idx = col if flow == 'source_to_target' else row
deg = scatter(edge_weight, idx, dim=0, dim_size=num_nodes, reduce='sum')
deg_inv_sqrt = deg.pow_(-0.5)
deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

return edge_index, edge_weight
  • 确保 flow 参数在允许的范围内。
  • 计算节点数。
    maybe_num_nodestorch_geometric.utils.num_nodes 模块中的一个函数,用于推断图的节点数量。
  • 如果需要添加自环,则添加。
  • 如果 edge_weight 为空,则初始化为全 1 的张量。
  • 根据 flow 参数选择索引进行度矩阵计算并归一化处理。
    PS:当edge_weight均为1的时候,scatter得到每个节点的度赋值给deg,deg为从0到num_nodes的下标对应位置的节点度(但当edge_weight不为1时,得到的就是节点的强度)。
    PS:deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 得到的是归一化后的权重。

scatter函数 —(用于将某个张量 src 的值分散到另一个张量 out 中的函数)

torch_scatter.scatter(src, index, dim=-1, out=None, dim_size=None, reduce='sum')

参数说明

  • src:源张量,包含要散布的值。
  • index:索引张量,指定 src 中的元素将散布到 out 张量中的哪些位置。
  • dim:沿着哪个维度进行散布操作。
  • out:目标张量,如果不提供,将创建一个新张量。
  • dim_size:目标张量的大小,如果不提供,将根据 index 自动计算。
  • reduce:聚合操作的类型,例如 ‘sum’(求和)、‘mean’(平均)、‘max’(最大值)等。
  • 24
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值