[240905] 如何使用 JAX 和 Equinox 构建图卷积网络 | Cascadia 字体家族迎来新成员

如何使用 JAX 和 Equinox 构建图卷积网络

文章介绍如何使用 JAX 和 Equinox 构建图卷积网络 (GNN)。我们将分别使用邻接矩阵和边列表两种方式实现图卷积层,并比较它们的优缺点。

1 使用邻接矩阵

1.1 邻接矩阵表示法

对于包含 NNN 个节点的图,我们可以使用一个 N×NN \times NN×N 的邻接矩阵 AAA 来表示节点之间的关系。 矩阵元素 ai,j∈{0,1}a_{i, j} \in \{0, 1\}ai,j\u200b∈{0,1} 表示从节点 jjj 到节点 iii 是否存在边 (1 表示存在,0 表示不存在)。

1.2 图卷积层实现
import jax.experimental.sparse as jsparse

class GraphConv(eqx.Module):
    linear: nn.Linear

    def __init__(self, hidden_dim: int, *, key: PRNGKeyArray):
        self.linear = nn.Linear(hidden_dim, hidden_dim, key=key)

    def __call__(
        self,
        nodes: Float[Array, "n_nodes hidden_dim"],
        adjacency: Int[jsparse.BCOO, "n_nodes n_nodes"]
    ) -> Float[Array, "n_nodes hidden_dim"]:
        messages = vmap(self.linear)(nodes)
        return adjacency @ messages
1.3 计算过程解释

上述代码中,我们首先对所有节点应用线性变换,然后将结果与邻接矩阵相乘。这等效于对每个节点 iii,将其所有邻居节点 jjj 的特征进行加权求和,其中权重由线性变换矩阵 WWW 决定。

2 使用边列表

2.1 边列表表示法

边列表使用一个 M×2M \times 2M×2 的张量 EEE 来表示图中的边。其中, ek=(j,i)e_k = (j, i)ek\u200b=(j,i) 表示第 kkk 条边是从节点 jjj 到节点 iii。

2.2 图卷积层实现
class GraphConv(eqx.Module):
    linear: nn.Linear

    def __init__(self, hidden_dim: int, *, key: PRNGKeyArray):
        self.linear = nn.Linear(hidden_dim, hidden_dim, key=key)

    def __call__(
        self,
        nodes: Float[Array, "n_nodes hidden_dim"],
        edges: Int[Array, "n_edges 2"],
    ) -> Float[Array, "n_nodes hidden_dim"]:
        messages = vmap(self.linear)(nodes)
        messages = messages[edges[:, 0]]  # 获取源节点特征
        messages = jax.ops.segment_sum(
            data=messages,
            segment_ids=edges[:, 1],
            num_segments=len(nodes),
        )  # 按目标节点聚合特征
        return messages
2.3 代码解析:jax.ops.segment_sum

jax.ops.segment_sum 函数用于根据 segment_ids 对数据进行分组求和。在本例中,我们将所有边的源节点 特征按照目标节点 ID 进行分组求和,从而得到每个目标节点的聚合特征。

2.4 计算节点度数示例
ones = jnp.ones(len(edges), dtype=jnp.int32)
degrees = jax.ops.segment_sum(
    data=ones,
    segment_ids=edges[:, 1],
    num_segments=len(nodes),
)
2.5 边列表表示法的优势
  • 灵活性更高: 可以使用不同的聚合函数 (例如 segment_minsegment_max),以及对边特征进行线性变换。
  • 更易于实现复杂的 GNN 模型: 例如 GAT (图注意力网络)。

3 模型训练

3.1 任务设置:节点排序

我们构建了一个节点排序任务来测试 GNN 模型。首先生成随机图,然后根据节点的聚类系数为每个节点分配一个 分数。

3.2 模型:GCN 和 GAT

我们分别使用邻接矩阵和边列表实现了 GCN 和 GAT 两种 GNN 模型。

3.3 训练结果

在包含 800 个随机图的数据集上进行训练,结果表明 GCN 在该任务上表现略优于 GAT。

4 JIT 优化技巧

4.1 JIT 原理

JAX 的 JIT (Just-In-Time) 编译机制可以显著提高代码运行效率。首次调用函数时,JIT 会将其编译并缓存,下次调用相同函数时直接使用缓存结果。

4.2 图数据形状问题

由于不同图的节点数和边数不同,JIT 缓存机制可能会导致频繁的重新编译。

4.3 解决方案:填充图数据

为了避免频繁的重新编译,我们需要对图数据进行填充,使其形状保持一致。

5 两种表示法的优缺点

  • 邻接矩阵:计算效率高,内存占用少,但灵活性较低。
  • 边列表:灵活性高,易于实现复杂模型,但计算效率和内存占用略逊于邻接矩阵。
进一步阅读

https://github.com/pierrot-lc/gnn-tuto

来源:

https://pierrot-lc.github.io/website/2024/09/02/tuto-gnn.html

Cascadia 字体家族迎来新成员:Cascadia Next SC、TC 和 JP 预发布!

微软开源字体 Cascadia Code 迎来了重大更新!除了原有的英文字体外,现在新增了简体中文 (SC)、繁体中文 (TC) 和日语 (JP) 三种变体,为更多开发者带来更好的编码体验。

Cascadia Next 由微软设计师 @aaronbell 精心打造,目前预发布版本包含以下字符集:

  • 简体中文:ASCII, GB2312 扩展
  • 繁体中文:ASCII, BIG5+
  • 日语:ASCII, Joyo, JIS1, JIS2

需要注意的是,本次预发布版本暂不支持阿拉伯语、希伯来语和 NerdFonts。

微软团队非常重视用户的反馈,希望广大开发者积极尝试新字体,并提出宝贵意见,帮助他们进一步完善 Cascadia Next。

立即体验 Cascadia Next:
https://github.com/microsoft/cascadia-code/releases/tag/cascadia-next

来源:

https://github.com/microsoft/cascadia-code/releases/tag/cascadia-next

更多内容请查阅 : blog-240905


关注微信官方公众号 : oh my x

获取开源软件和 x-cmd 最新用法

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值