NVIDIA cuEquivariance 详细教程:引言与概述

NVIDIA cuEquivariance 详细教程:引言与概述

在这里插入图片描述

文章目录

1.1 什么是 NVIDIA cuEquivariance

NVIDIA cuEquivariance 是一个专为构建高性能等变神经网络而设计的 Python 库,它通过分段张量积的方式实现了对称性的保持。在深度学习领域,保持数据的内在对称性是提高模型效率和泛化能力的关键因素,而 cuEquivariance 正是为解决这一挑战而生。

作为 NVIDIA 开发的高性能计算库,cuEquivariance 提供了一套全面的 API 来描述分段张量积,并配备了经过优化的 CUDA 内核来执行这些操作。这使得研究人员和开发者能够构建既保持几何对称性又具有高计算效率的深度学习模型。

cuEquivariance 的一个显著特点是它提供了与主流深度学习框架的无缝集成。它包含了 PyTorch 和 JAX 的绑定,使得用户可以在自己熟悉的框架中轻松使用等变神经网络的强大功能。这种灵活性使得 cuEquivariance 成为各种科学计算和深度学习应用的理想选择。

值得一提的是,cuEquivariance 的前端是开源的,在 Apache 2.0 许可下可在 GitHub 上获取,这进一步促进了社区协作和技术创新。

1.2 等变神经网络的基本概念

等变性(Equivariance)是一个数学概念,它描述了当输入数据发生某种变换时,输出也会以可预测的方式发生相应变换的特性。简单来说,等变性是"尊重对称性"这一概念的数学形式化。

在深度学习中,等变神经网络是指那些能够保持输入数据中特定对称性的网络结构。例如,如果我们旋转一个图像,一个旋转等变的网络会产生相应旋转的特征表示,而不是完全不同的表示。这与传统神经网络不同,传统网络通常需要通过数据增强等技术来学习这种不变性。

等变神经网络的核心优势在于它们能够将物理世界的对称性直接编码到网络架构中,而不是让网络从数据中学习这些对称性。这带来了几个重要好处:

  1. 数据效率:由于网络已经"知道"了数据的对称性,它需要更少的训练数据来学习有效的表示。
  2. 泛化能力:等变网络通常在面对未见过的数据时表现更好,因为它们能够利用数据的内在对称性。
  3. 解释性:等变网络的行为更加可预测和可解释,因为它们遵循明确的数学原则。
  4. 计算效率:在某些情况下,等变结构可以减少需要学习的参数数量,从而提高计算效率。

在物理学和化学等领域,许多系统表现出对三维空间中旋转和平移的等变性。例如,分子的能量不应该依赖于它在空间中的方向。通过使用等变神经网络,我们可以确保模型尊重这些物理约束,从而获得更准确和物理上合理的预测。

1.3 cuEquivariance 的主要特性和优势

NVIDIA cuEquivariance 提供了一系列强大的特性和优势,使其成为构建等变神经网络的首选工具:

高性能 CUDA 加速:cuEquivariance 利用 NVIDIA GPU 的强大计算能力,通过优化的 CUDA 内核实现了高效的张量操作。这使得即使是复杂的等变网络也能实现实时推理和快速训练。

灵活的 API 设计:库提供了直观而强大的 API,使用户能够轻松定义和操作群表示。Irreps 类允许用户以简洁的方式描述复杂的表示结构,而 SegmentedTensorProduct 类则提供了构建自定义张量积的灵活性。

多框架支持:cuEquivariance 提供了与 PyTorch 和 JAX 的无缝集成,使用户可以在自己熟悉的深度学习框架中使用等变网络。这种多框架支持极大地降低了学习曲线,并促进了现有项目的集成。

全面的群表示支持:库支持多种常见的群及其表示,包括 SO(3)(三维旋转群)、O(3)(三维旋转和反射群)以及 SU(2)(特殊幺正群)。这使得 cuEquivariance 适用于广泛的科学和工程应用。

优化的内存管理:通过精心设计的数据布局和内存访问模式,cuEquivariance 实现了高效的内存利用,减少了在大规模模型训练中常见的内存瓶颈。

丰富的预定义描述符:库提供了多种预定义的张量积描述符,如线性层、球谐函数和旋转操作,简化了常见等变操作的实现。

实验性高级功能:cuEquivariance 还包含了一些实验性的高级功能,如 JIT 编译内核和融合的散射/收集操作,为追求极致性能的用户提供了更多选择。

开源社区支持:作为一个开源项目,cuEquivariance 受益于活跃的开发者社区,不断改进和扩展其功能。

这些特性和优势使 cuEquivariance 成为一个强大而灵活的工具,能够满足从基础研究到工业应用的各种等变神经网络需求。

1.4 cuEquivariance 的应用场景

NVIDIA cuEquivariance 的强大功能使其在多个领域有着广泛的应用。以下是一些主要的应用场景:

分子动力学与药物发现:在分子模拟和药物设计中,分子的性质不应依赖于它们在空间中的方向。cuEquivariance 可以构建尊重这种旋转等变性的模型,从而更准确地预测分子性质、相互作用和反应动力学。这对于加速药物发现过程和理解复杂生物系统至关重要。

材料科学:新材料的设计和分析需要理解原子和分子在三维空间中的排列。等变神经网络可以从原子坐标直接学习材料性质,而不需要人工设计特征。这使得研究人员能够更快地探索材料空间,发现具有所需性质的新材料。

量子化学计算:量子系统通常表现出对称性,cuEquivariance 可以构建尊重这些对称性的模型,提高量子化学计算的准确性和效率。这对于理解复杂分子系统的电子结构和性质非常有价值。

3D 计算机视觉:在点云处理、3D 物体识别和场景理解等任务中,物体的身份不应依赖于它们的方向。等变神经网络可以自然地处理这种旋转不变性,提高识别准确率和泛化能力。

机器人学与控制系统:机器人需要在三维空间中导航和操作物体。等变神经网络可以帮助机器人更好地理解空间关系,实现更精确的运动规划和物体操作。

气象学与流体动力学:天气预报和流体模拟涉及复杂的三维流场,这些流场通常表现出某种形式的对称性。cuEquivariance 可以构建尊重这些对称性的模型,提高预测准确性。

粒子物理学:高能物理实验产生的数据通常具有旋转对称性。等变神经网络可以更有效地分析这些数据,帮助物理学家发现新粒子或验证理论预测。

蛋白质结构预测:蛋白质的三维结构对其功能至关重要。等变神经网络可以从氨基酸序列预测蛋白质结构,同时尊重蛋白质折叠过程中的物理约束。

这些应用场景展示了 cuEquivariance 的多功能性和在科学计算中的重要价值。通过将物理对称性直接编码到神经网络架构中,cuEquivariance 使模型能够更准确、更高效地解决复杂的科学和工程问题。

1.5 本教程的内容和结构

本教程旨在提供一个全面而深入的 NVIDIA cuEquivariance 指南,从基础概念到高级应用,帮助读者充分利用这个强大的库来构建高性能等变神经网络。无论您是刚接触等变神经网络的初学者,还是寻求提高模型性能的经验丰富的研究人员,本教程都能为您提供有价值的信息和实用技巧。

教程的结构如下:

首先,我们从安装与部署开始,详细介绍如何在不同环境中安装 cuEquivariance 及其依赖项。我们将覆盖系统要求、使用 pip 安装的步骤、从源代码构建的方法,以及如何验证安装是否成功。我们还将讨论常见的安装问题及其解决方案,确保您能够顺利开始使用这个库。

接下来,我们深入探讨群论基础与表示理论,这是理解等变神经网络的理论基础。我们将以直观的方式解释群、群表示和不可约表示的概念,特别关注 SO(3) 和 O(3) 群及其在物理系统中的应用。通过具体的例子,如应力张量的不可约表示分解,我们将展示这些抽象概念如何应用于实际问题。

在掌握了理论基础后,我们将详细介绍 cuEquivariance 的核心组件,包括 Irreps 类、数据布局和等变张量积。我们将通过丰富的代码示例展示如何创建和操作这些组件,为构建复杂的等变网络奠定基础。

随后,我们将通过一系列基本示例高级功能的讲解,展示如何使用 cuEquivariance 构建各种等变神经网络。从简单的等变线性层到复杂的分段张量积,我们将提供详细的代码示例和解释,帮助您理解每个步骤的原理和实现。所有代码示例都将包含中文注释,使您能够更容易地理解和应用这些概念。

为了展示 cuEquivariance 在实际中的应用,我们将介绍几个实际应用案例,包括球谐函数与旋转等变性、构建等变神经网络模型、分子动力学中的应用和 3D 点云处理。这些案例将展示如何将理论知识应用于解决实际问题,并提供性能优化的技巧。

我们还将探讨 cuEquivariance 的一些Beta 特性与实验性功能,如 JIT 内核和融合的散射/收集内核,这些功能可以进一步提高模型的性能。

最后,我们将提供故障排除与常见问题的指南,帮助您解决在使用 cuEquivariance 过程中可能遇到的问题,以及总结与进阶资源,为您指明进一步学习和探索的方向。

通过本教程,您将获得构建和优化等变神经网络的全面知识和实践经验,能够将 cuEquivariance 的强大功能应用于您自己的研究和应用中。让我们开始这个激动人心的学习之旅吧!

NVIDIA cuEquivariance 详细教程:安装与部署

2.1 系统要求

在开始安装 NVIDIA cuEquivariance 之前,需要确保您的系统满足以下要求:

硬件要求

  • NVIDIA GPU(推荐 Volta、Turing、Ampere 或更新架构)
  • 足够的 GPU 内存(至少 4GB,对于大型模型建议 8GB 或更多)
  • 足够的系统内存(至少 8GB,建议 16GB 或更多)

软件要求

  • 操作系统:
    • Linux(推荐 Ubuntu 18.04 或更高版本)
    • 对于 cuequivariance-ops-torch-* 包:支持 Linux x86_64/aarch64
    • 对于 cuequivariance-ops-jax-cu12 包:仅支持 Linux x86_64
  • CUDA 工具包:
    • 对于 cu11 版本:CUDA 11.x
    • 对于 cu12 版本:CUDA 12.x
  • Python 版本:3.8 或更高版本(对于 aarch64 架构,仅支持 Python 3.12)
  • 深度学习框架:
    • 对于 PyTorch 前端:PyTorch 2.4.0 或更高版本
    • 对于 JAX 前端:JAX 0.5.0 或更高版本

在安装 cuEquivariance 之前,建议先安装并配置好 CUDA 工具包和相应的深度学习框架。这将确保 cuEquivariance 能够正确识别和使用您的 GPU 资源。

2.2 使用 pip 安装

使用 pip 安装是获取 NVIDIA cuEquivariance 最简单、最推荐的方法。根据您的需求,可以选择安装不同的组件。

2.2.1 安装核心组件

核心组件包含所有非机器学习的功能,如果您只需要基本的等变操作而不需要与特定深度学习框架集成,可以只安装这个组件:

pip install cuequivariance

这将安装 cuequivariance 包,提供群表示、等变张量积描述符等基础功能。

2.2.2 安装 PyTorch 前端

如果您计划在 PyTorch 项目中使用 cuEquivariance,需要安装 PyTorch 前端:

pip install cuequivariance-torch

这将安装 cuequivariance_torch 包,提供与 PyTorch 集成的接口和功能。

2.2.3 安装 JAX 前端

如果您计划在 JAX 项目中使用 cuEquivariance,需要安装 JAX 前端:

pip install cuequivariance-jax

这将安装 cuequivariance_jax 包,提供与 JAX 集成的接口和功能。

2.2.4 安装 CUDA 内核

为了获得最佳性能,您还需要安装适合您环境的 CUDA 内核。根据您使用的深度学习框架和 CUDA 版本,选择相应的包:

对于 PyTorch + CUDA 11.x

pip install cuequivariance-ops-torch-cu11

对于 PyTorch + CUDA 12.x

pip install cuequivariance-ops-torch-cu12

对于 JAX + CUDA 12.x

pip install cuequivariance-ops-jax-cu12

请注意,目前 JAX 前端仅支持 CUDA 12.x。

一键安装示例

以下是一些常见安装场景的一键安装命令:

PyTorch + CUDA 11.x 完整安装

pip install cuequivariance cuequivariance-torch cuequivariance-ops-torch-cu11

PyTorch + CUDA 12.x 完整安装

pip install cuequivariance cuequivariance-torch cuequivariance-ops-torch-cu12

JAX + CUDA 12.x 完整安装

pip install cuequivariance cuequivariance-jax cuequivariance-ops-jax-cu12

2.3 从源代码构建

如果您需要最新的功能或者想要修改库的行为,可以从源代码构建 cuEquivariance。以下是从源代码构建的步骤:

1. 克隆仓库

首先,从 GitHub 克隆 cuEquivariance 仓库:

git clone https://github.com/NVIDIA/cuequivariance.git
cd cuequivariance

2. 创建并激活虚拟环境(可选但推荐)

python -m venv venv
source venv/bin/activate  # 在 Linux/macOS 上
# 或
venv\Scripts\activate  # 在 Windows 上

3. 安装构建依赖

pip install -e ".[dev]"

这将安装开发所需的所有依赖项。

4. 构建 CUDA 扩展

对于 PyTorch 扩展:

cd cuequivariance-ops-torch
python setup.py install

对于 JAX 扩展:

cd cuequivariance-ops-jax
python setup.py install

5. 安装前端包

cd cuequivariance-torch  # 或 cuequivariance-jax
pip install -e .

从源代码构建可能需要更多的时间和技术知识,但它提供了最大的灵活性和对库的控制。

2.4 验证安装

安装完成后,建议验证安装是否成功。以下是一些简单的验证方法:

验证核心组件

# 创建一个简单的 Python 脚本 verify_core.py
import cuequivariance as cue

# 创建一个 Irreps 对象
irreps = cue.Irreps("SO3", "1x0 + 1x1")
print(f"成功创建 Irreps 对象: {irreps}")
print(f"维度: {irreps.dim}")

运行脚本:

python verify_core.py

如果安装正确,您应该看到类似以下的输出:

成功创建 Irreps 对象: 1x0 + 1x1
维度: 4

验证 PyTorch 前端

# 创建一个简单的 Python 脚本 verify_torch.py
import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

# 创建一个简单的等变线性层
irreps_in = cue.Irreps("SO3", "10x0 + 5x1")
irreps_out = cue.Irreps("SO3", "5x0 + 3x1")
e = cue.descriptors.linear(irreps_in, irreps_out)

# 创建一个 PyTorch 模块
module = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, use_fallback=True)

# 创建一些随机输入
batch_size = 2
x = torch.randn(batch_size, irreps_in.dim)
w = torch.randn(1, e.inputs[0].dim)

# 执行前向传播
y = module(w, x)

print(f"输入形状: {x.shape}")
print(f"权重形状: {w.shape}")
print(f"输出形状: {y.shape}")
print(f"预期输出维度: {irreps_out.dim}")

运行脚本:

python verify_torch.py

如果安装正确,您应该看到输出形状与预期维度匹配。

验证 JAX 前端

# 创建一个简单的 Python 脚本 verify_jax.py
import jax
import jax.numpy as jnp
import cuequivariance as cue
import cuequivariance_jax as cuex

# 创建一个简单的等变线性层
irreps_in = cue.Irreps("SO3", "10x0 + 5x1")
irreps_out = cue.Irreps("SO3", "5x0 + 3x1")
e = cue.descriptors.linear(irreps_in, irreps_out)

# 创建一些随机输入
key = jax.random.key(0)
key1, key2 = jax.random.split(key)
batch_size = 2
x = cuex.randn(key1, (batch_size, irreps_in))
w = cuex.randn(key2, e.inputs[0])

# 执行等变张量积
y = cuex.equivariant_tensor_product(e, w, x)

print(f"输入形状: {x.shape}")
print(f"权重形状: {w.shape}")
print(f"输出形状: {y.shape}")
print(f"预期输出维度: {irreps_out.dim}")

运行脚本:

python verify_jax.py

如果安装正确,您应该看到输出形状与预期维度匹配。

验证 GPU 支持

要验证 cuEquivariance 是否正确使用 GPU,可以修改上述脚本,明确将张量放在 GPU 上:

对于 PyTorch

# 在 verify_torch.py 中添加
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# 将张量移动到设备上
x = x.to(device)
w = w.to(device)
module = module.to(device)

# 其余代码保持不变

对于 JAX

# 在 verify_jax.py 中添加
print(f"可用设备: {jax.devices()}")

# JAX 会自动使用可用的 GPU

2.5 常见安装问题及解决方案

在安装和使用 cuEquivariance 过程中,可能会遇到一些常见问题。以下是这些问题的解决方案:

1. CUDA 版本不匹配

症状:安装时出现 CUDA 版本不兼容的错误,或者运行时出现 “CUDA error: no kernel image is available for execution” 等错误。

解决方案

  • 确认您系统上安装的 CUDA 版本:nvcc --version
  • 确保安装的 cuEquivariance 包与您的 CUDA 版本匹配
  • 如果需要,可以重新安装匹配的版本:
    pip uninstall cuequivariance-ops-torch-cu11 cuequivariance-ops-torch-cu12
    pip install cuequivariance-ops-torch-cu11  # 或 cu12,取决于您的 CUDA 版本
    

2. PyTorch/JAX 版本不兼容

症状:导入 cuequivariance_torchcuequivariance_jax 时出现版本不兼容错误。

解决方案

  • 检查您的 PyTorch 或 JAX 版本:
    import torch; print(torch.__version__)  # 对于 PyTorch
    import jax; print(jax.__version__)      # 对于 JAX
    
  • 确保 PyTorch 版本 ≥ 2.4.0 或 JAX 版本 ≥ 0.5.0
  • 如果需要,升级您的深度学习框架:
    pip install --upgrade torch
    # 或
    pip install --upgrade jax
    

3. 找不到 CUDA 设备

症状:代码运行在 CPU 上而不是 GPU 上,或者出现 “CUDA not available” 错误。

解决方案

  • 确认 CUDA 设备是否可见:
    nvidia-smi
    
  • 检查 CUDA 环境变量:
    echo $CUDA_VISIBLE_DEVICES
    
  • 确保 PyTorch 或 JAX 能够检测到 GPU:
    import torch; print(torch.cuda.is_available())  # 对于 PyTorch
    import jax; print(jax.devices())                # 对于 JAX
    
  • 如果使用容器或虚拟环境,确保 CUDA 设备已正确传递给容器或环境

4. 内存错误

症状:运行时出现 “CUDA out of memory” 或类似错误。

解决方案

  • 减小批量大小或模型大小
  • 使用梯度累积技术
  • 尝试使用混合精度训练(对于 PyTorch):
    from torch.cuda.amp import autocast
    
    with autocast():
        output = model(input)
    
  • 监控 GPU 内存使用情况:
    nvidia-smi -l 1  # 每秒更新一次
    

5. 编译错误(从源代码构建时)

症状:从源代码构建时出现编译错误。

解决方案

  • 确保已安装所有必要的构建工具:
    # 在 Ubuntu 上
    sudo apt-get install build-essential
    
  • 确保 CUDA 工具链配置正确:
    echo $PATH  # 应包含 CUDA bin 目录
    echo $LD_LIBRARY_PATH  # 应包含 CUDA lib 目录
    
  • 尝试清理构建目录后重新构建:
    rm -rf build/
    python setup.py clean
    python setup.py install
    

6. 导入错误

症状:导入模块时出现 “ImportError” 或 “ModuleNotFoundError”。

解决方案

  • 确认所有必要的包都已安装:
    pip list | grep cuequivariance
    
  • 检查 Python 路径:
    import sys; print(sys.path)
    
  • 尝试重新安装包:
    pip uninstall cuequivariance cuequivariance-torch cuequivariance-jax
    pip install cuequivariance cuequivariance-torch cuequivariance-jax
    

7. aarch64 架构特定问题

症状:在 ARM64 架构(如 NVIDIA Jetson 或 Apple M1/M2)上安装失败。

解决方案

  • 确认您使用的是 Python 3.12(目前 aarch64 仅支持此版本)
  • 确保安装的是适用于 aarch64 的包
  • 对于 Apple Silicon,可能需要使用 Rosetta 2 运行 x86_64 版本

通过以上步骤,您应该能够成功安装和配置 NVIDIA cuEquivariance,并解决在此过程中可能遇到的常见问题。如果您遇到未在此列出的问题,建议查阅官方文档或在 GitHub 仓库中提交 issue。

NVIDIA cuEquivariance 详细教程:群论基础与表示理论

3.1 什么是群(Group)

群是数学中的一个基本代数结构,它在等变神经网络的理论基础中扮演着核心角色。简单来说,群是一组元素和一个二元运算的集合,满足特定的数学性质。这些性质使得群成为描述对称性和变换的强大工具。

群的形式定义

从数学角度看,群是一个集合 G 和一个二元运算 ·(通常称为群乘法),满足以下四个公理:

  1. 闭合性:对于群 G 中的任意两个元素 a 和 b,它们的乘积 a·b 也是群 G 中的元素。
  2. 结合律:对于群 G 中的任意三个元素 a、b 和 c,(a·b)·c = a·(b·c)。
  3. 单位元:群 G 中存在一个元素 e(称为单位元),对于群中的任意元素 a,都有 e·a = a·e = a。
  4. 逆元:对于群 G 中的每个元素 a,都存在一个元素 a^(-1)(称为 a 的逆元),使得 a·a^(-1) = a^(-1)·a = e。

群的直观理解

虽然群的数学定义看起来有些抽象,但我们可以通过一些直观的例子来理解它:

整数加法群:考虑所有整数的集合 Z 和加法运算 +。这构成了一个群,因为:

  • 任意两个整数的和仍然是整数(闭合性)
  • 加法满足结合律:(a + b) + c = a + (b + c)
  • 0 是加法的单位元:a + 0 = 0 + a = a
  • 每个整数 a 都有一个加法逆元 -a:a + (-a) = (-a) + a = 0

旋转群:考虑二维平面上所有旋转变换的集合,其中运算是变换的组合(先进行一个旋转,再进行另一个旋转)。这也构成了一个群:

  • 两个旋转的组合仍然是一个旋转(闭合性)
  • 旋转的组合满足结合律
  • 旋转 0 度是单位元
  • 每个旋转都有一个逆旋转(例如,顺时针旋转 θ 度的逆是逆时针旋转 θ 度)

群在等变神经网络中的重要性

在等变神经网络的背景下,群用于描述我们希望网络尊重的对称性或变换。例如:

  • 如果我们处理的是图像数据,可能关心的是旋转、平移或缩放等变换,这些可以用相应的群来描述。
  • 如果我们处理的是分子数据,可能关心的是三维空间中的旋转和反射,这可以用 O(3) 群来描述。

通过将这些对称性编码到网络架构中,我们可以确保网络的预测在这些变换下表现出适当的行为,从而提高模型的数据效率和泛化能力。

3.2 群表示(Group Representation)的概念

群表示是将抽象群元素映射到具体数学对象(通常是矩阵)的方法,使得群的结构和性质在这些对象上得到保持。表示理论是连接抽象群论和实际应用的桥梁,在等变神经网络中起着关键作用。

群表示的形式定义

从数学角度看,群 G 的表示是一个映射 ρ: G → GL(V),其中 GL(V) 是向量空间 V 上的可逆线性变换群(或等价地,可逆矩阵群)。这个映射需要满足:

  1. ρ(g₁·g₂) = ρ(g₁)·ρ(g₂),对于群 G 中的任意元素 g₁ 和 g₂
  2. ρ(e) = I,其中 e 是群 G 的单位元,I 是单位矩阵

简单来说,表示将群的乘法结构映射到矩阵的乘法结构,保持了群的代数性质。

群表示的直观理解

想象我们有一个抽象的群,比如三维空间中的旋转群 SO(3)。这个群的元素是抽象的旋转操作,但我们希望用具体的数学对象(如矩阵)来表示它们,以便进行计算。

例如,我们可以将 SO(3) 中的旋转映射到 3×3 的正交矩阵,使得:

  • 每个旋转操作对应一个特定的矩阵
  • 两个旋转的组合对应于相应矩阵的乘积
  • 单位旋转(不旋转)对应于单位矩阵

这就是 SO(3) 的一个表示。通过这种表示,我们可以用矩阵运算来模拟旋转操作,从而在计算机中实现这些操作。

表示的维数

表示的维数是指表示矩阵的维数,也就是向量空间 V 的维数。例如:

  • 1 维表示将群元素映射到 1×1 矩阵(标量)
  • 3 维表示将群元素映射到 3×3 矩阵

不同维数的表示可以捕获群的不同方面,并在不同的应用中发挥作用。

表示在等变神经网络中的应用

在等变神经网络中,我们使用群表示来:

  1. 定义特征的变换行为:通过指定特征向量属于哪个表示,我们定义了它们在群变换下如何变化。
  2. 构建等变层:等变层需要在不同表示之间进行映射,同时保持等变性。
  3. 实现等变操作:如卷积、注意力机制等,需要根据表示的性质进行适当修改。

例如,在处理 3D 点云数据时,我们可能使用 SO(3) 的不同表示来表示不同类型的特征:

  • 标量特征(如点的密度)使用 1 维表示
  • 向量特征(如点的法向量)使用 3 维表示
  • 更高阶特征可能使用更高维的表示

通过正确处理这些表示,我们可以构建对旋转等变的神经网络,使其在输入旋转时产生一致的预测。

3.3 不可约表示(Irreducible Representation)

不可约表示(简称 irrep)是表示理论中的基本构件,类似于向量空间中的基向量。理解不可约表示对于掌握等变神经网络至关重要,因为它们是构建等变层的基础。

不可约表示的定义

从数学角度看,群 G 的不可约表示是指不能被分解为更小表示的直和的表示。更具体地说,如果表示 ρ 作用的向量空间 V 没有非平凡的不变子空间(除了 {0} 和 V 本身之外),那么 ρ 就是不可约的。

这个定义可能看起来有些抽象,但其核心思想是:不可约表示是最"基本"的表示,不能被进一步简化。

不可约表示的重要性

不可约表示之所以重要,是因为:

  1. 完备性:任何表示都可以分解为不可约表示的直和。这类似于任何向量都可以表示为基向量的线性组合。
  2. 正交性:不同不可约表示之间满足一定的正交关系,使得它们在数学上易于处理。
  3. 物理意义:在物理学中,不可约表示通常对应于具有明确物理意义的量,如标量、向量、张量等。

常见群的不可约表示

SO(3)(三维旋转群)的不可约表示

SO(3) 的不可约表示由一个非负整数 l 标记,通常记为 D^l。第 l 个不可约表示的维数是 2l+1。一些常见的例子:

  • l=0:1 维表示(标量表示),对应于在旋转下不变的量
  • l=1:3 维表示(向量表示),对应于在旋转下变换的三维向量
  • l=2:5 维表示,对应于二阶对称无迹张量

O(3)(三维旋转和反射群)的不可约表示

O(3) 可以看作是 SO(3) 和 Z₂(反演群)的半直积。其不可约表示由一对 (l, p) 标记,其中 l 是非负整数,p 是 +1(偶宇称)或 -1(奇宇称)。常见的例子:

  • (0, +1):标量表示,在旋转和反射下都不变
  • (1, -1):向量表示,在旋转下变换为向量,在反射下变号
  • (0, -1):伪标量表示,在旋转下不变,在反射下变号

在 cuEquivariance 中使用不可约表示

在 cuEquivariance 中,我们使用 Irreps 类来描述不可约表示的集合。例如:

# 导入必要的库
import cuequivariance as cue

# 创建一个包含 SO(3) 不可约表示的集合
irreps = cue.Irreps("SO3", "1x0 + 3x1 + 2x2")

这个例子创建了一个表示,包含:

  • 1 个 l=0 的不可约表示(1 维)
  • 3 个 l=1 的不可约表示(每个 3 维)
  • 2 个 l=2 的不可约表示(每个 5 维)

总维数为:1×1 + 3×3 + 2×5 = 1 + 9 + 10 = 20。

对于 O(3) 群,我们可以指定宇称:

# 创建一个包含 O(3) 不可约表示的集合
irreps = cue.Irreps("O3", "1x0e + 3x1o + 2x2e")

这里,“e” 表示偶宇称 (+1),“o” 表示奇宇称 (-1)。

不可约表示的直和

在实际应用中,我们通常使用不可约表示的直和来表示复杂的特征。直和简单地将多个不可约表示组合在一起,形成一个更大的表示。

例如,考虑一个特征向量,其中前 5 个元素是标量(l=0),后 15 个元素可以组织为 5 个 3 维向量(l=1)。我们可以用以下方式表示这个特征:

irreps = cue.Irreps("SO3", "5x0 + 5x1")

这个表示的总维数是 5×1 + 5×3 = 5 + 15 = 20。

通过这种方式,我们可以灵活地组合不同类型的特征,同时保持它们在群变换下的正确变换行为。

3.4 SO(3) 群及其不可约表示

SO(3) 是三维空间中的特殊正交群,也称为三维旋转群。它在物理学、计算机图形学和等变神经网络中有广泛的应用。理解 SO(3) 及其不可约表示对于构建处理三维数据的等变网络至关重要。

SO(3) 群的定义

SO(3) 是所有保持原点不变且保持距离和方向的三维空间变换的集合。形式上,它是所有满足以下条件的 3×3 实矩阵 R 的集合:

  1. R^T R = I(正交性)
  2. det® = 1(特殊性,保持方向)

这些条件确保了 SO(3) 中的变换只进行旋转,不进行反射或缩放。

SO(3) 的参数化

SO(3) 可以用多种方式参数化,包括:

  1. 欧拉角:使用三个角度(通常称为 roll、pitch、yaw)来描述旋转。
  2. 轴角表示:使用一个单位向量(旋转轴)和一个角度来描述旋转。
  3. 四元数:使用四个实数组成的四元数来描述旋转,避免了欧拉角的万向锁问题。
  4. 旋转矩阵:直接使用 3×3 正交矩阵来描述旋转。

在 cuEquivariance 中,旋转通常使用旋转矩阵表示,但库也提供了在不同表示之间转换的功能。

SO(3) 的不可约表示

SO(3) 的不可约表示由一个非负整数 l 标记,通常记为 D^l。第 l 个不可约表示的维数是 2l+1。这些不可约表示有着深刻的物理意义:

l=0(标量表示)

  • 维数:2×0+1 = 1
  • 物理意义:在旋转下不变的量,如质量、电荷、温度等
  • 变换行为:D^0® = 1,对于任何旋转 R

l=1(向量表示)

  • 维数:2×1+1 = 3
  • 物理意义:三维空间中的向量,如位置、速度、力等
  • 变换行为:D^1® = R,即旋转矩阵本身

l=2(二阶张量表示)

  • 维数:2×2+1 = 5
  • 物理意义:二阶对称无迹张量,如应力张量的无迹部分
  • 变换行为:更复杂,但可以用球谐函数表示

更高阶表示(l>2)

  • 维数:2l+1
  • 物理意义:高阶多极矩、高阶球谐函数等
  • 变换行为:可以用 Wigner D-矩阵表示

球谐函数与 SO(3) 不可约表示

球谐函数 Y_l^m(θ,φ) 与 SO(3) 的不可约表示密切相关。事实上,对于给定的 l,2l+1 个球谐函数 Y_l^m(m 从 -l 到 l)形成了 SO(3) 第 l 个不可约表示的一个基。

在 cuEquivariance 中,我们可以使用球谐函数来构建等变特征:

# 创建一个球谐函数描述符
import cuequivariance as cue
sh_descriptor = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2])

这个描述符将输入向量映射到 l=0、l=1 和 l=2 的球谐函数,生成一个等变特征。

SO(3) 表示的张量积

两个 SO(3) 不可约表示的张量积可以分解为不可约表示的直和。具体来说,D^l₁ ⊗ D^l₂ 可以分解为:

D^l₁ ⊗ D^l₂ = ⊕_{l=|l₁-l₂|}^{l₁+l₂} D^l

例如,两个向量表示(l=1)的张量积可以分解为:

D^1 ⊗ D^1 = D^0 ⊕ D^1 ⊕ D^2

这对应于两个三维向量的张量积(9 维)分解为一个标量(1 维)、一个向量(3 维)和一个二阶对称无迹张量(5 维)。

在 cuEquivariance 中,我们可以使用等变张量积来实现这种分解:

# 创建两个 SO(3) 不可约表示
irreps1 = cue.Irreps("SO3", "1x1")  # 一个向量
irreps2 = cue.Irreps("SO3", "1x1")  # 另一个向量

# 创建它们的等变张量积
tensor_product = cue.descriptors.tensor_product(irreps1, irreps2)

这个张量积将自动按照 SO(3) 表示的分解规则进行分解。

3.5 O(3) 群及其不可约表示

O(3) 是三维空间中的正交群,包括旋转和反射。它是 SO(3) 的扩展,增加了空间反演操作。O(3) 在处理具有反射对称性的物理系统中特别重要,如分子、晶体结构等。

O(3) 群的定义

O(3) 是所有保持原点不变且保持距离的三维空间变换的集合。形式上,它是所有满足条件 R^T R = I(正交性)的 3×3 实矩阵 R 的集合。

与 SO(3) 不同,O(3) 不要求 det® = 1,因此它包括:

  • det® = 1 的矩阵,对应于旋转(SO(3) 的元素)
  • det® = -1 的矩阵,对应于旋转后跟空间反演

O(3) 的结构

O(3) 可以看作是 SO(3) 和 Z₂(反演群)的半直积:O(3) = SO(3) ⋊ Z₂。

Z₂ 是只有两个元素的群:单位元 e 和反演元素 P。反演元素满足 P² = e,对应于空间坐标的反转:(x, y, z) → (-x, -y, -z)。

这种结构意味着 O(3) 中的任何元素都可以唯一地表示为一个 SO(3) 元素(旋转)和一个 Z₂ 元素(是否反演)的组合。

O(3) 的不可约表示

由于 O(3) = SO(3) ⋊ Z₂ 的结构,O(3) 的不可约表示可以从 SO(3) 和 Z₂ 的不可约表示构建。

O(3) 的不可约表示由一对 (l, p) 标记,其中:

  • l 是非负整数,对应于 SO(3) 的不可约表示
  • p 是 +1(偶宇称)或 -1(奇宇称),对应于 Z₂ 的表示

第 (l, p) 个不可约表示的维数仍然是 2l+1,与相应的 SO(3) 表示相同。

O(3) 不可约表示的物理意义

O(3) 不可约表示的物理意义与宇称(parity)密切相关,宇称描述了物理量在空间反演下的变换行为:

标量表示 (0, +1)

  • 在旋转和反射下都不变的量
  • 例如:质量、电荷、温度

伪标量表示 (0, -1)

  • 在旋转下不变,但在反射下变号的量
  • 例如:磁通量、手性度量

向量表示 (1, -1)

  • 在旋转下变换为向量,在反射下变号
  • 例如:位置、速度、电场

伪向量表示 (1, +1)

  • 在旋转下变换为向量,在反射下不变号
  • 例如:角动量、磁场

在 cuEquivariance 中使用 O(3) 表示

在 cuEquivariance 中,我们可以使用 “e”(偶宇称)和 “o”(奇宇称)后缀来指定 O(3) 不可约表示:

# 创建一个包含 O(3) 不可约表示的集合
import cuequivariance as cue
irreps = cue.Irreps("O3", "1x0e + 3x1o + 2x2e")

这个例子创建了一个表示,包含:

  • 1 个 (0, +1) 不可约表示(偶宇称标量)
  • 3 个 (1, -1) 不可约表示(奇宇称向量)
  • 2 个 (2, +1) 不可约表示(偶宇称二阶张量)

O(3) 表示的张量积

O(3) 不可约表示的张量积遵循与 SO(3) 类似的规则,但需要考虑宇称。两个 O(3) 不可约表示 (l₁, p₁) 和 (l₂, p₂) 的张量积可以分解为:

(l₁, p₁) ⊗ (l₂, p₂) = ⊕_{l=|l₁-l₂|}^{l₁+l₂} (l, p₁·p₂)

注意宇称相乘:偶 × 偶 = 偶,偶 × 奇 = 奇,奇 × 奇 = 偶。

在 cuEquivariance 中,等变张量积会自动处理这些宇称规则:

# 创建两个 O(3) 不可约表示
irreps1 = cue.Irreps("O3", "1x1o")  # 一个奇宇称向量
irreps2 = cue.Irreps("O3", "1x1o")  # 另一个奇宇称向量

# 创建它们的等变张量积
tensor_product = cue.descriptors.tensor_product(irreps1, irreps2)
# 结果将包含偶宇称表示,因为奇 × 奇 = 偶

通过正确处理 O(3) 表示及其张量积,我们可以构建既尊重旋转对称性又尊重反射对称性的等变神经网络。

3.6 实例:应力张量的不可约表示分解

应力张量是材料力学中的一个基本概念,它描述了材料内部一点处的应力状态。通过将应力张量分解为不可约表示,我们可以更好地理解其物理意义,并在等变神经网络中正确处理它。

应力张量的基本概念

应力张量 σ 是一个二阶张量,通常表示为一个 3×3 矩阵:

σ = [ σ₁₁ σ₁₂ σ₁₃ ]
    [ σ₂₁ σ₂₂ σ₂₃ ]
    [ σ₃₁ σ₃₂ σ₃₃ ]

其中,σᵢⱼ 表示在 i 方向上的面上作用的 j 方向的应力分量。

应力张量通常是对称的(σᵢⱼ = σⱼᵢ),这反映了角动量守恒。因此,它有 6 个独立分量,而不是 9 个。

应力张量在旋转下的变换

当坐标系统旋转时,应力张量按以下方式变换:

σ’ = R σ R^T

其中 R 是旋转矩阵,σ 是原始应力张量,σ’ 是旋转后的应力张量。

这种变换行为表明应力张量不是 SO(3) 的不可约表示,而是一个可约表示。我们可以将其分解为不可约表示的直和。

应力张量的不可约分解

对称的 3×3 应力张量可以分解为两个不可约表示的直和:

σ = σ₀ I + σ₂

其中:

  • σ₀ = (1/3) Tr(σ) 是应力张量的迹(除以 3),对应于静水压力。这是一个标量,属于 l=0 不可约表示。
  • σ₂ = σ - σ₀ I 是应力张量的无迹部分,对应于偏应力。这是一个对称无迹张量,属于 l=2 不可约表示。

在 SO(3) 不可约表示的语言中,这对应于:

3×3 对称张量 = D^0 ⊕ D^2

其中 D^0 是 1 维的标量表示,D^2 是 5 维的二阶对称无迹张量表示。

在 cuEquivariance 中实现应力张量分解

我们可以使用 cuEquivariance 来实现应力张量的不可约分解:

import cuequivariance as cue
import torch
import cuequivariance_torch as cuet
import numpy as np

# 创建一个随机对称应力张量
def create_symmetric_stress_tensor():
    # 创建一个随机 3x3 矩阵
    stress = torch.randn(3, 3)
    # 使其对称
    stress = (stress + stress.T) / 2
    return stress

# 将应力张量分解为不可约表示
def decompose_stress_tensor(stress):
    # 计算迹(标量部分)
    trace = torch.trace(stress)
    scalar_part = trace / 3
    
    # 计算无迹部分(l=2 表示)
    identity = torch.eye(3)
    traceless_part = stress - scalar_part * identity
    
    return scalar_part, traceless_part

# 创建一个等变网络来处理应力张量
def create_stress_tensor_network():
    # 定义输入表示:1x0 (标量部分) + 1x2 (无迹部分)
    irreps_in = cue.Irreps("SO3", "1x0 + 1x2")
    
    # 定义输出表示
    irreps_out = cue.Irreps("SO3", "1x0")  # 例如,预测一个标量属性
    
    # 创建一个等变线性层
    e = cue.descriptors.linear(irreps_in, irreps_out)
    module = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, use_fallback=True)
    
    return module

# 主函数
def main():
    # 创建一个随机应力张量
    stress = create_symmetric_stress_tensor()
    print("原始应力张量:")
    print(stress)
    
    # 分解为不可约表示
    scalar_part, traceless_part = decompose_stress_tensor(stress)
    print("\n标量部分 (l=0):")
    print(scalar_part)
    print("\n无迹部分 (l=2):")
    print(traceless_part)
    
    # 验证分解的正确性
    reconstructed = scalar_part * torch.eye(3) + traceless_part
    print("\n重构的应力张量:")
    print(reconstructed)
    print("\n重构误差:")
    print(torch.norm(stress - reconstructed))
    
    # 创建一个处理应力张量的等变网络
    network = create_stress_tensor_network()
    
    # 将分解后的表示转换为网络输入格式
    # 注意:这里需要将 3x3 的无迹部分转换为 5 维向量
    # 这涉及到球谐函数的细节,这里简化处理
    scalar_input = scalar_part.unsqueeze(0)
    
    # 简化:使用随机 5 维向量代替实际的转换
    # 在实际应用中,应使用适当的转换函数
    traceless_input = torch.randn(1, 5)
    
    # 组合输入
    network_input = torch.cat([scalar_input, traceless_input], dim=1)
    
    # 创建权重
    weights = torch.randn(1, network.weight_numel)
    
    # 前向传播
    output = network(weights, network_input)
    print("\n网络输出 (标量):")
    print(output)

if __name__ == "__main__":
    main()

这个例子展示了如何:

  1. 创建一个对称应力张量
  2. 将其分解为标量部分(l=0)和无迹部分(l=2)
  3. 验证分解的正确性
  4. 创建一个等变网络来处理分解后的应力张量

应力张量分解的物理意义

应力张量的不可约分解具有明确的物理意义:

  1. 标量部分(l=0)

    • 表示静水压力或平均正应力
    • 导致体积变化但不导致形状变化
    • 在所有方向上均匀作用
  2. 无迹部分(l=2)

    • 表示偏应力或剪应力
    • 导致形状变化但不导致体积变化
    • 在不同方向上作用不同

这种分解使我们能够分别研究材料的体积变化和形状变化,这在材料科学和固体力学中非常有用。

在等变神经网络中的应用

在等变神经网络中,通过将应力张量分解为不可约表示,我们可以:

  1. 正确处理其在旋转下的变换行为
  2. 分别处理不同物理意义的部分
  3. 构建尊重物理对称性的模型

例如,在预测材料的弹性响应时,我们可以使用等变网络来处理应力张量,确保预测结果在坐标系旋转下保持一致。

通过这个实例,我们看到了群表示理论如何帮助我们理解物理量的本质,以及如何在等变神经网络中正确处理它们。这种方法不仅适用于应力张量,也适用于其他物理张量,如应变张量、弹性张量等。

NVIDIA cuEquivariance 详细教程:核心组件

4.1 Irreps 类详解

Irreps 类是 cuEquivariance 库的核心组件之一,它用于描述群表示中存在的不可约表示(irreps)及其多重性。理解和掌握 Irreps 类的使用对于构建等变神经网络至关重要。

4.1.1 创建和使用 Irreps 对象

Irreps 类提供了一种简洁而强大的方式来描述不可约表示的集合。让我们从基本的创建和使用开始:

# 导入必要的库
import cuequivariance as cue

# 创建一个简单的 Irreps 对象
irreps = cue.Irreps("SO3", "10x0 + 5x1")

# 查看 Irreps 对象的基本信息
print(f"Irreps: {irreps}")  # 输出: 10x0 + 5x1
print(f"维度: {irreps.dim}")  # 输出: 10*1 + 5*3 = 25
print(f"不可约表示数量: {len(irreps)}")  # 输出: 2

在这个例子中:

  • 我们创建了一个 SO(3) 群的表示
  • 这个表示包含 10 个 l=0 的不可约表示(标量)和 5 个 l=1 的不可约表示(向量)
  • 总维度是 101 + 53 = 25

Irreps 对象是可迭代的,我们可以遍历其中的每个不可约表示:

# 遍历 Irreps 对象中的每个不可约表示
for mul, ir in irreps:
    print(f"多重性: {mul}, 不可约表示: {ir}, 维度: {ir.dim}")

输出将是:

多重性: 10, 不可约表示: 0, 维度: 1
多重性: 5, 不可约表示: 1, 维度: 3

我们还可以通过索引访问特定的不可约表示:

# 访问第一个不可约表示
mul, ir = irreps[0]
print(f"第一个不可约表示: 多重性={mul}, 类型={ir}")

# 访问第二个不可约表示
mul, ir = irreps[1]
print(f"第二个不可约表示: 多重性={mul}, 类型={ir}")

4.1.2 Irreps 字符串表示法

Irreps 类使用一种简洁的字符串表示法来描述不可约表示的集合。这种表示法的基本格式是:

[多重性]x[不可约表示标签] + [多重性]x[不可约表示标签] + ...

其中:

  • 多重性:表示该不可约表示出现的次数
  • 不可约表示标签:表示不可约表示的类型,具体格式取决于群

对于不同的群,不可约表示标签的格式如下:

SO(3) 群

  • 标签格式:l
  • 其中 l 是非负整数,表示角动量量子数
  • 例如:"3x0 + 2x1 + 1x2"

O(3) 群

  • 标签格式:lelo
  • 其中 l 是非负整数,e 表示偶宇称,o 表示奇宇称
  • 例如:"3x0e + 2x1o + 1x2e"

SU(2) 群

  • 标签格式:jj/2
  • 其中 j 是非负整数或半整数,表示自旋量子数
  • 例如:"2x0 + 3x1/2 + 1x1"

为了简化表示,当多重性为 1 时,可以省略 1x

# 这两种表示是等价的
irreps1 = cue.Irreps("SO3", "1x0 + 1x1 + 1x2")
irreps2 = cue.Irreps("SO3", "0 + 1 + 2")
print(irreps1 == irreps2)  # 输出: True

我们还可以使用乘法来简化表示:

# 这两种表示是等价的
irreps1 = cue.Irreps("O3", "5x0e + 5x0e + 5x0e")
irreps2 = cue.Irreps("O3", "3x5x0e")
print(irreps1 == irreps2)  # 输出: True

4.1.3 设置默认群

如果我们在一个项目中主要使用一种群,可以使用 cue.assume 上下文管理器来设置默认群,避免每次都指定群:

# 不使用默认群
irreps1 = cue.Irreps("SO3", "0 + 1 + 2")

# 使用默认群
with cue.assume(cue.SO3):
    irreps2 = cue.Irreps("0 + 1 + 2")
    
print(irreps1 == irreps2)  # 输出: True

我们也可以全局设置默认群:

# 全局设置默认群
cue.set_default_group(cue.O3)

# 现在可以省略群参数
irreps = cue.Irreps("0e + 1o + 2e")

Irreps 类的高级操作

Irreps 类还提供了一些高级操作,用于处理和转换不可约表示:

1. 合并相同类型的不可约表示

# 合并相同类型的不可约表示
irreps = cue.Irreps("SO3", "2x0 + 3x0 + 1x1 + 4x1")
irreps_simplified = irreps.simplify()
print(irreps_simplified)  # 输出: 5x0 + 5x1

2. 获取表示的总维度

irreps = cue.Irreps("SO3", "10x0 + 5x1 + 2x2")
total_dim = irreps.dim
print(f"总维度: {total_dim}")  # 输出: 10*1 + 5*3 + 2*5 = 10 + 15 + 10 = 35

3. 检查两个 Irreps 是否兼容

irreps1 = cue.Irreps("SO3", "10x0 + 5x1")
irreps2 = cue.Irreps("SO3", "10x0 + 5x1")
irreps3 = cue.Irreps("SO3", "5x0 + 10x1")

print(irreps1 == irreps2)  # 输出: True
print(irreps1 == irreps3)  # 输出: False

4. 创建 Irreps 的子集

irreps = cue.Irreps("SO3", "3x0 + 2x1 + 1x2")
# 获取前两个不可约表示
sub_irreps = irreps[:2]
print(sub_irreps)  # 输出: 3x0 + 2x1

通过掌握 Irreps 类的使用,我们可以灵活地描述和操作等变神经网络中的特征表示,为构建复杂的等变模型奠定基础。

4.2 数据布局

在处理等变神经网络时,数据的组织方式(布局)对于性能和可用性有重要影响。cuEquivariance 提供了灵活的数据布局选项,使用户可以根据自己的需求选择最合适的布局。

4.2.1 (ir, mul) 布局

(ir, mul) 布局将不可约表示作为最外层维度,多重性作为内层维度。这种布局在某些情况下更自然,特别是当我们想要按不可约表示类型处理数据时。

在这种布局中,数据的组织方式如下:

[batch_size, ir_1, mul_1, ir_2, mul_2, ..., ir_n, mul_n]

例如,对于表示 "2x0 + 3x1"

  • 第一部分是 2 个标量(l=0),维度为 2×1=2
  • 第二部分是 3 个向量(l=1),维度为 3×3=9
  • 总维度为 2+9=11

在 (ir, mul) 布局中,这 11 个值按照先不可约表示后多重性的顺序排列:

  • 前 2 个值是 2 个标量
  • 后 9 个值是 3 个向量,每个向量有 3 个分量

这种布局在 cuEquivariance 中通过 cue.ir_mul 常量表示:

import cuequivariance as cue
import torch

# 创建一个表示
irreps = cue.Irreps("SO3", "2x0 + 3x1")

# 创建一个随机张量,使用 (ir, mul) 布局
batch_size = 4
x = torch.randn(batch_size, irreps.dim)  # 形状为 [4, 11]

# 指定使用 (ir, mul) 布局
layout = cue.ir_mul

# 在等变操作中使用这个布局
# ...

4.2.2 (mul, ir) 布局

(mul, ir) 布局将多重性作为最外层维度,不可约表示作为内层维度。这种布局在某些框架(如 e3nn)中更常用,特别是当我们想要按多重性(通常对应于通道)处理数据时。

在这种布局中,数据的组织方式如下:

[batch_size, mul_1, ir_1, mul_2, ir_2, ..., mul_n, ir_n]

对于同样的表示 "2x0 + 3x1",在 (mul, ir) 布局中:

  • 前 2 个值是 2 个标量,每个标量有 1 个分量
  • 后 9 个值是 3 个向量,每个向量有 3 个分量

但它们的排列顺序是按照先多重性后不可约表示:

  • 第一个标量的 1 个分量
  • 第二个标量的 1 个分量
  • 第一个向量的 3 个分量
  • 第二个向量的 3 个分量
  • 第三个向量的 3 个分量

这种布局在 cuEquivariance 中通过 cue.mul_ir 常量表示:

import cuequivariance as cue
import torch

# 创建一个表示
irreps = cue.Irreps("SO3", "2x0 + 3x1")

# 创建一个随机张量,使用 (mul, ir) 布局
batch_size = 4
x = torch.randn(batch_size, irreps.dim)  # 形状为 [4, 11]

# 指定使用 (mul, ir) 布局
layout = cue.mul_ir

# 在等变操作中使用这个布局
# ...

4.2.3 布局转换

cuEquivariance 提供了在不同布局之间转换的功能。这在与其他使用不同布局的库集成时特别有用。

手动转换布局

import cuequivariance as cue
import torch

# 创建一个表示
irreps = cue.Irreps("SO3", "2x0 + 3x1")

# 创建一个使用 (ir, mul) 布局的张量
batch_size = 4
x_ir_mul = torch.randn(batch_size, irreps.dim)

# 将张量从 (ir, mul) 布局转换为 (mul, ir) 布局
x_mul_ir = cue.convert_layout(x_ir_mul, irreps, cue.ir_mul, cue.mul_ir)

# 将张量从 (mul, ir) 布局转换回 (ir, mul) 布局
x_ir_mul_again = cue.convert_layout(x_mul_ir, irreps, cue.mul_ir, cue.ir_mul)

# 验证转换是否正确
print(torch.allclose(x_ir_mul, x_ir_mul_again))  # 输出: True

在等变操作中自动处理布局

cuEquivariance 的等变操作(如 EquivariantTensorProduct)可以自动处理布局转换。我们只需指定输入和期望的输出布局:

import cuequivariance as cue
import cuequivariance_torch as cuet
import torch

# 创建一个等变线性层描述符
irreps_in = cue.Irreps("SO3", "10x0 + 5x1")
irreps_out = cue.Irreps("SO3", "5x0 + 3x1")
e = cue.descriptors.linear(irreps_in, irreps_out)

# 创建一个使用 (mul, ir) 布局的等变模块
module = cuet.EquivariantTensorProduct(e, layout=cue.mul_ir, use_fallback=True)

# 创建输入张量(假设使用 (ir, mul) 布局)
batch_size = 2
x = torch.randn(batch_size, irreps_in.dim)  # (ir, mul) 布局
w = torch.randn(1, e.inputs[0].dim)

# 模块会自动将输入从 (ir, mul) 转换为 (mul, ir),然后执行操作,
# 最后将输出转换回 (mul, ir) 布局
y = module(w, x)  # y 使用 (mul, ir) 布局

布局选择的考虑因素

在选择数据布局时,需要考虑以下因素:

  1. 与其他库的兼容性:如果您正在与其他使用特定布局的库(如 e3nn)集成,可能需要选择匹配的布局。

  2. 性能考虑:在某些情况下,一种布局可能比另一种布局更高效,特别是在涉及特定硬件优化时。

  3. 代码可读性和维护性:选择一种在您的应用中最自然和直观的布局,可以提高代码的可读性和维护性。

  4. 内存访问模式:不同的布局会导致不同的内存访问模式,这可能影响缓存效率和整体性能。

在 cuEquivariance 中,默认布局是 cue.ir_mul,但库提供了灵活的选项来使用和转换不同的布局,以适应各种需求。

4.3 等变张量积

等变张量积是构建等变神经网络的核心操作,它允许我们在保持等变性的同时组合不同的表示。cuEquivariance 提供了强大的工具来定义和执行等变张量积。

4.3.1 EquivariantTensorProduct 类

EquivariantTensorProduct 类是 cuEquivariance 中用于描述等变张量积的主要类。它封装了输入和输出表示的信息,以及如何在它们之间执行等变映射的规则。

创建 EquivariantTensorProduct 对象

import cuequivariance as cue

# 定义输入和输出表示
irreps_in1 = cue.Irreps("SO3", "10x0 + 5x1")  # 第一个输入
irreps_in2 = cue.Irreps("SO3", "3x0 + 2x1")   # 第二个输入
irreps_out = cue.Irreps("SO3", "5x0 + 3x1")   # 输出

# 创建等变张量积描述符
etp = cue.EquivariantTensorProduct(
    irreps_in=[irreps_in1, irreps_in2],  # 输入表示列表
    irreps_out=irreps_out,               # 输出表示
    # 其他参数...
)

# 查看等变张量积的基本信息
print(f"输入表示: {etp.irreps_in}")
print(f"输出表示: {etp.irreps_out}")
print(f"参数数量: {etp.weight_numel}")

EquivariantTensorProduct 的主要属性

  • irreps_in:输入表示的列表
  • irreps_out:输出表示
  • weight_numel:权重参数的数量
  • ds:底层的分段张量积描述符列表

在 PyTorch 中使用 EquivariantTensorProduct

import cuequivariance as cue
import cuequivariance_torch as cuet
import torch

# 创建等变张量积描述符
irreps_in1 = cue.Irreps("SO3", "10x0 + 5x1")
irreps_in2 = cue.Irreps("SO3", "3x0 + 2x1")
irreps_out = cue.Irreps("SO3", "5x0 + 3x1")
etp = cue.descriptors.tensor_product(irreps_in1, irreps_in2, irreps_out)

# 创建 PyTorch 模块
module = cuet.EquivariantTensorProduct(etp, layout=cue.ir_mul, use_fallback=True)

# 创建输入张量
batch_size = 2
x1 = torch.randn(batch_size, irreps_in1.dim)
x2 = torch.randn(batch_size, irreps_in2.dim)
w = torch.randn(1, etp.weight_numel)

# 执行前向传播
y = module(w, x1, x2)
print(f"输出形状: {y.shape}")  # 应该是 [batch_size, irreps_out.dim]

在 JAX 中使用 EquivariantTensorProduct

import cuequivariance as cue
import cuequivariance_jax as cuex
import jax
import jax.numpy as jnp

# 创建等变张量积描述符
irreps_in1 = cue.Irreps("SO3", "10x0 + 5x1")
irreps_in2 = cue.Irreps("SO3", "3x0 + 2x1")
irreps_out = cue.Irreps("SO3", "5x0 + 3x1")
etp = cue.descriptors.tensor_product(irreps_in1, irreps_in2, irreps_out)

# 创建随机输入
key = jax.random.key(0)
key1, key2, key3 = jax.random.split(key, 3)
batch_size = 2
x1 = jax.random.normal(key1, (batch_size, irreps_in1.dim))
x2 = jax.random.normal(key2, (batch_size, irreps_in2.dim))
w = jax.random.normal(key3, (etp.weight_numel,))

# 执行等变张量积
y = cuex.equivariant_tensor_product(etp, w, x1, x2)
print(f"输出形状: {y.shape}")  # 应该是 [batch_size, irreps_out.dim]

4.3.2 常用张量积描述符

cuEquivariance 提供了多种预定义的张量积描述符,用于常见的等变操作。这些描述符可以通过 cue.descriptors 模块访问。

1. 线性层

线性层是最基本的等变操作,它将一个表示映射到另一个表示,同时保持等变性。

import cuequivariance as cue

# 创建一个等变线性层描述符
irreps_in = cue.Irreps("SO3", "10x0 + 5x1")
irreps_out = cue.Irreps("SO3", "5x0 + 3x1")
linear = cue.descriptors.linear(irreps_in, irreps_out)

print(f"输入表示: {linear.irreps_in}")
print(f"输出表示: {linear.irreps_out}")
print(f"权重参数数量: {linear.weight_numel}")

线性层只允许相同类型的不可约表示之间有连接。例如,标量只能映射到标量,向量只能映射到向量,等等。这确保了等变性。

2. 球谐函数

球谐函数是一种特殊的等变映射,它将输入向量映射到球谐函数值。

import cuequivariance as cue

# 创建一个球谐函数描述符
# 将 3D 向量映射到 l=0,1,2,3 的球谐函数
sh = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3])

print(f"输入表示: {sh.irreps_in}")  # 应该是 1x1
print(f"输出表示: {sh.irreps_out}")  # 应该是 1x0 + 1x1 + 1x2 + 1x3

球谐函数是构建等变卷积网络的重要组件,特别是在处理 3D 数据时。

3. 旋转

旋转描述符用于实现输入的旋转变换。

import cuequivariance as cue

# 创建一个旋转描述符
irreps = cue.Irreps("SO3", "10x0 + 5x1")
rotation = cue.descriptors.yxy_rotation(irreps)

print(f"输入表示: {rotation.irreps_in}")  # 应该与 irreps 相同
print(f"输出表示: {rotation.irreps_out}")  # 应该与 irreps 相同

旋转操作保持输入和输出表示相同,但应用旋转变换。

4. 张量积

一般的张量积允许我们组合两个输入表示,生成一个新的输出表示。

import cuequivariance as cue

# 创建一个张量积描述符
irreps_in1 = cue.Irreps("SO3", "1x0 + 1x1")
irreps_in2 = cue.Irreps("SO3", "1x1")
# 不指定输出表示,将自动计算所有可能的输出
tensor_product = cue.descriptors.tensor_product(irreps_in1, irreps_in2)

print(f"输入表示: {tensor_product.irreps_in}")
print(f"输出表示: {tensor_product.irreps_out}")

张量积遵循群表示理论中的张量积分解规则。例如,两个 l=1 表示的张量积分解为 l=0、l=1 和 l=2 表示的直和。

5. 通道间张量积

通道间张量积在每个通道上独立应用相同的张量积操作。

import cuequivariance as cue

# 创建一个通道间张量积描述符
irreps_in1 = cue.Irreps("SO3", "5x0 + 5x1")  # 5 个通道
irreps_in2 = cue.Irreps("SO3", "5x1")        # 5 个通道
irreps_out = cue.Irreps("SO3", "5x0 + 5x1 + 5x2")  # 每个通道的输出
channelwise_tp = cue.descriptors.channelwise_tensor_product(
    irreps_in1, irreps_in2, irreps_out
)

print(f"输入表示: {channelwise_tp.irreps_in}")
print(f"输出表示: {channelwise_tp.irreps_out}")

通道间张量积在处理多通道数据时特别有用,如图像或点云特征。

6. 自定义张量积

对于更复杂的需求,我们可以创建自定义的张量积描述符:

import cuequivariance as cue
import numpy as np

# 创建一个自定义张量积描述符
def create_custom_tensor_product(irreps_in1, irreps_in2, irreps_out):
    # 创建一个空的等变张量积
    etp = cue.EquivariantTensorProduct(
        irreps_in=[irreps_in1, irreps_in2],
        irreps_out=irreps_out
    )
    
    # 添加自定义的分段张量积描述符
    # 这里只是一个示例,实际应用中需要根据具体需求设计
    d = cue.SegmentedTensorProduct.from_subscripts("uv,iu,iv")
    
    # 添加段和路径
    # ...
    
    # 将描述符添加到等变张量积中
    etp.ds.append(d)
    
    return etp

# 使用自定义函数创建张量积描述符
irreps_in1 = cue.Irreps("SO3", "3x0 + 2x1")
irreps_in2 = cue.Irreps("SO3", "2x0 + 1x1")
irreps_out = cue.Irreps("SO3", "1x0 + 1x1")
custom_tp = create_custom_tensor_product(irreps_in1, irreps_in2, irreps_out)

自定义张量积提供了最大的灵活性,但需要深入理解底层的分段张量积机制。

通过这些预定义的描述符,cuEquivariance 使构建复杂的等变神经网络变得简单而灵活。我们可以组合这些基本构件来创建适合特定应用需求的等变架构。

NVIDIA cuEquivariance 详细教程:基本示例:线性层

5.1 创建等变线性层

等变线性层是构建等变神经网络的基础组件,它允许我们在保持等变性的同时,将一种表示映射到另一种表示。在本节中,我们将详细介绍如何使用 cuEquivariance 创建和使用等变线性层。

等变线性层的原理

等变线性层与普通神经网络中的线性层(全连接层)类似,但有一个关键区别:等变线性层必须保持输入数据的对称性。这意味着只有相同类型的不可约表示之间才能有连接。例如:

  • 标量(l=0)只能映射到标量
  • 向量(l=1)只能映射到向量
  • 二阶张量(l=2)只能映射到二阶张量

这种限制确保了层的等变性,即当输入经过变换(如旋转)时,输出会以一致的方式变换。

使用 cuEquivariance 创建等变线性层

让我们从一个简单的例子开始,创建一个将 SO(3) 表示映射到另一个 SO(3) 表示的等变线性层:

# 导入必要的库
import cuequivariance as cue
import torch
import cuequivariance_torch as cuet

# 定义输入和输出表示
irreps_in = cue.Irreps("SO3", "10x0 + 5x1")   # 10个标量和5个向量
irreps_out = cue.Irreps("SO3", "5x0 + 3x1")   # 5个标量和3个向量

# 创建等变线性层描述符
linear_descriptor = cue.descriptors.linear(irreps_in, irreps_out)

# 查看描述符的基本信息
print(f"输入表示: {linear_descriptor.irreps_in}")
print(f"输出表示: {linear_descriptor.irreps_out}")
print(f"权重参数数量: {linear_descriptor.weight_numel}")

在这个例子中:

  • 输入表示包含 10 个标量(l=0,每个 1 维)和 5 个向量(l=1,每个 3 维),总维度为 10×1 + 5×3 = 25
  • 输出表示包含 5 个标量(l=0,每个 1 维)和 3 个向量(l=1,每个 3 维),总维度为 5×1 + 3×3 = 14
  • 权重参数数量为 10×5 + 5×3 = 50 + 15 = 65,对应于标量到标量的连接和向量到向量的连接

理解权重参数

等变线性层的权重参数数量取决于输入和输出表示中相同类型的不可约表示的多重性。对于每对相同类型的不可约表示,权重数量是输入多重性乘以输出多重性。

例如,对于上面的例子:

  • 标量(l=0):输入多重性为 10,输出多重性为 5,因此有 10×5 = 50 个权重参数
  • 向量(l=1):输入多重性为 5,输出多重性为 3,因此有 5×3 = 15 个权重参数
  • 总权重参数数量为 50 + 15 = 65

这些权重参数控制着不同通道之间的连接强度,同时保持每种不可约表示的内部结构不变,从而确保等变性。

创建更复杂的等变线性层

我们也可以创建更复杂的等变线性层,包含更多类型的不可约表示:

# 创建包含更多不可约表示类型的等变线性层
irreps_in_complex = cue.Irreps("O3", "10x0e + 5x1o + 3x2e")   # 包含偶宇称标量、奇宇称向量和偶宇称二阶张量
irreps_out_complex = cue.Irreps("O3", "5x0e + 3x1o + 2x2e")   # 输出类似的结构但多重性不同

# 创建等变线性层描述符
linear_complex = cue.descriptors.linear(irreps_in_complex, irreps_out_complex)

# 查看描述符的基本信息
print(f"输入表示: {linear_complex.irreps_in}")
print(f"输出表示: {linear_complex.irreps_out}")
print(f"权重参数数量: {linear_complex.weight_numel}")

在这个更复杂的例子中,我们使用了 O(3) 群(包括旋转和反射),并添加了二阶张量(l=2)表示。权重参数数量将是:

  • 偶宇称标量(0e):10×5 = 50
  • 奇宇称向量(1o):5×3 = 15
  • 偶宇称二阶张量(2e):3×2 = 6
  • 总计:50 + 15 + 6 = 71

5.2 在 PyTorch 中使用等变线性层

一旦我们创建了等变线性层描述符,就可以在 PyTorch 中使用它来构建等变神经网络。下面是一个完整的例子:

import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

class EquivariantMLP(torch.nn.Module):
    """
    一个简单的等变多层感知机
    """
    def __init__(self, irreps_in, irreps_hidden, irreps_out):
        """
        初始化等变MLP
        
        参数:
            irreps_in: 输入表示
            irreps_hidden: 隐藏层表示
            irreps_out: 输出表示
        """
        super().__init__()
        
        # 创建第一个等变线性层
        self.linear1_desc = cue.descriptors.linear(irreps_in, irreps_hidden)
        self.linear1 = cuet.EquivariantTensorProduct(
            self.linear1_desc, 
            layout=cue.ir_mul,  # 使用(ir, mul)布局
            use_fallback=True   # 如果CUDA内核不可用,使用回退实现
        )
        
        # 创建第二个等变线性层
        self.linear2_desc = cue.descriptors.linear(irreps_hidden, irreps_out)
        self.linear2 = cuet.EquivariantTensorProduct(
            self.linear2_desc, 
            layout=cue.ir_mul,
            use_fallback=True
        )
        
        # 初始化权重
        self.weights1 = torch.nn.Parameter(torch.randn(1, self.linear1_desc.weight_numel))
        self.weights2 = torch.nn.Parameter(torch.randn(1, self.linear2_desc.weight_numel))
    
    def forward(self, x):
        """
        前向传播
        
        参数:
            x: 输入张量,形状为[batch_size, irreps_in.dim]
            
        返回:
            输出张量,形状为[batch_size, irreps_out.dim]
        """
        # 第一个线性层
        h = self.linear1(self.weights1, x)
        
        # 非线性激活(只应用于标量部分)
        # 注意:我们需要小心处理非标量部分,以保持等变性
        # 这里我们只对标量部分应用ReLU
        scalar_dim = sum(mul * ir.dim for mul, ir in self.linear1_desc.irreps_out if ir.l == 0)
        if scalar_dim > 0:
            h_scalar = h[:, :scalar_dim]
            h_vector = h[:, scalar_dim:]
            h_scalar = torch.relu(h_scalar)
            h = torch.cat([h_scalar, h_vector], dim=1)
        
        # 第二个线性层
        out = self.linear2(self.weights2, h)
        
        return out

# 创建一个等变MLP实例
irreps_in = cue.Irreps("SO3", "10x0 + 5x1")       # 输入:10个标量和5个向量
irreps_hidden = cue.Irreps("SO3", "20x0 + 10x1")  # 隐藏层:20个标量和10个向量
irreps_out = cue.Irreps("SO3", "5x0 + 3x1")       # 输出:5个标量和3个向量

model = EquivariantMLP(irreps_in, irreps_hidden, irreps_out)

# 创建一些随机输入数据
batch_size = 4
x = torch.randn(batch_size, irreps_in.dim)

# 前向传播
y = model(x)

print(f"输入形状: {x.shape}")  # [4, 25]
print(f"输出形状: {y.shape}")  # [4, 14]

在这个例子中,我们创建了一个简单的等变多层感知机(MLP),包含两个等变线性层和一个只应用于标量部分的ReLU激活函数。这种选择性地应用非线性是必要的,因为直接对非标量部分应用标准激活函数会破坏等变性。

处理非线性激活

在等变神经网络中,处理非线性激活是一个挑战,因为大多数标准激活函数(如ReLU、Sigmoid等)不保持等变性。有几种方法可以解决这个问题:

  1. 只对标量部分应用非线性:如上例所示,我们可以只对标量部分应用标准激活函数,因为标量在旋转下不变。

  2. 使用门控非线性:我们可以使用标量特征来调制非标量特征,例如:

def gated_nonlinearity(features, irreps):
    """
    使用标量门控的非线性激活
    
    参数:
        features: 特征张量
        irreps: 特征的不可约表示
        
    返回:
        应用非线性后的特征
    """
    # 分离标量和非标量部分
    scalar_indices = []
    vector_indices = []
    
    start_idx = 0
    for i, (mul, ir) in enumerate(irreps):
        dim = mul * ir.dim
        if ir.l == 0:  # 标量
            scalar_indices.extend(range(start_idx, start_idx + dim))
        else:  # 非标量
            vector_indices.extend(range(start_idx, start_idx + dim))
        start_idx += dim
    
    # 提取标量和非标量部分
    scalar_features = features[:, scalar_indices]
    vector_features = features[:, vector_indices]
    
    # 应用非线性到标量部分
    scalar_features = torch.sigmoid(scalar_features)
    
    # 使用标量门控非标量部分
    # 假设标量和非标量部分的数量匹配(简化示例)
    gated_vector_features = vector_features * scalar_features.unsqueeze(-1)
    
    # 重新组合特征
    combined_features = torch.zeros_like(features)
    combined_features[:, scalar_indices] = scalar_features
    combined_features[:, vector_indices] = gated_vector_features
    
    return combined_features
  1. 使用等变非线性:一些特殊设计的非线性函数可以保持等变性,如球谐函数非线性。

使用预训练权重

在实际应用中,我们通常会训练模型并保存权重。以下是如何使用预训练权重的示例:

# 假设我们已经训练了模型并保存了权重
torch.save(model.state_dict(), "equivariant_mlp_weights.pth")

# 创建一个新模型并加载权重
new_model = EquivariantMLP(irreps_in, irreps_hidden, irreps_out)
new_model.load_state_dict(torch.load("equivariant_mlp_weights.pth"))

# 使用加载的模型进行推理
with torch.no_grad():
    test_input = torch.randn(1, irreps_in.dim)
    prediction = new_model(test_input)
    print(f"预测结果形状: {prediction.shape}")

5.3 在 JAX 中使用等变线性层

cuEquivariance 也提供了 JAX 支持,使我们可以在 JAX 生态系统中使用等变神经网络。以下是一个在 JAX 中使用等变线性层的例子:

import jax
import jax.numpy as jnp
import cuequivariance as cue
import cuequivariance_jax as cuex

class EquivariantMLPJax:
    """
    JAX版本的等变多层感知机
    """
    def __init__(self, irreps_in, irreps_hidden, irreps_out, key):
        """
        初始化等变MLP
        
        参数:
            irreps_in: 输入表示
            irreps_hidden: 隐藏层表示
            irreps_out: 输出表示
            key: JAX随机数生成器密钥
        """
        # 创建第一个等变线性层描述符
        self.linear1_desc = cue.descriptors.linear(irreps_in, irreps_hidden)
        
        # 创建第二个等变线性层描述符
        self.linear2_desc = cue.descriptors.linear(irreps_hidden, irreps_out)
        
        # 初始化权重
        key1, key2 = jax.random.split(key)
        self.weights1 = jax.random.normal(key1, (self.linear1_desc.weight_numel,))
        self.weights2 = jax.random.normal(key2, (self.linear2_desc.weight_numel,))
        
        # 记录表示信息
        self.irreps_in = irreps_in
        self.irreps_hidden = irreps_hidden
        self.irreps_out = irreps_out
    
    def __call__(self, x):
        """
        前向传播
        
        参数:
            x: 输入张量,形状为[batch_size, irreps_in.dim]
            
        返回:
            输出张量,形状为[batch_size, irreps_out.dim]
        """
        # 第一个线性层
        h = cuex.equivariant_tensor_product(self.linear1_desc, self.weights1, x)
        
        # 非线性激活(只应用于标量部分)
        scalar_dim = sum(mul * ir.dim for mul, ir in self.linear1_desc.irreps_out if ir.l == 0)
        if scalar_dim > 0:
            h_scalar = h[:, :scalar_dim]
            h_vector = h[:, scalar_dim:]
            h_scalar = jax.nn.relu(h_scalar)
            h = jnp.concatenate([h_scalar, h_vector], axis=1)
        
        # 第二个线性层
        out = cuex.equivariant_tensor_product(self.linear2_desc, self.weights2, h)
        
        return out

# 创建一个等变MLP实例
irreps_in = cue.Irreps("SO3", "10x0 + 5x1")       # 输入:10个标量和5个向量
irreps_hidden = cue.Irreps("SO3", "20x0 + 10x1")  # 隐藏层:20个标量和10个向量
irreps_out = cue.Irreps("SO3", "5x0 + 3x1")       # 输出:5个标量和3个向量

key = jax.random.key(0)
model = EquivariantMLPJax(irreps_in, irreps_hidden, irreps_out, key)

# 创建一些随机输入数据
batch_size = 4
input_key = jax.random.key(1)
x = jax.random.normal(input_key, (batch_size, irreps_in.dim))

# 前向传播
y = model(x)

print(f"输入形状: {x.shape}")  # (4, 25)
print(f"输出形状: {y.shape}")  # (4, 14)

JAX 版本的实现与 PyTorch 版本在概念上是相似的,但使用了 JAX 的函数式编程风格和 cuex.equivariant_tensor_product 函数而不是 PyTorch 模块。

JAX 中的 JIT 编译

JAX 的一个主要优势是能够使用即时编译(JIT)来加速计算。我们可以轻松地将等变网络的前向传播函数进行 JIT 编译:

# JIT编译前向传播函数
@jax.jit
def forward(model, x):
    return model(x)

# 使用JIT编译的函数
y_jit = forward(model, x)

print(f"JIT编译后的输出形状: {y_jit.shape}")  # (4, 14)

JAX 中的批量处理

JAX 还允许我们轻松地对函数进行向量化,以处理批量数据:

# 向量化前向传播函数以处理批量数据
batch_forward = jax.vmap(lambda m, x: m(x), in_axes=(None, 0))

# 创建一批输入数据
batch_size = 10
batch_key = jax.random.key(2)
batch_x = jax.random.normal(batch_key, (batch_size, irreps_in.dim))

# 使用向量化函数处理批量数据
batch_y = batch_forward(model, batch_x)

print(f"批量输出形状: {batch_y.shape}")  # (10, 14)

5.4 等变线性层的性能分析

等变线性层的性能是构建高效等变神经网络的关键考虑因素。在本节中,我们将分析等变线性层的计算复杂度和内存使用,并提供一些优化技巧。

计算复杂度

等变线性层的计算复杂度主要取决于以下因素:

  1. 输入和输出表示的维度:维度越高,计算量越大
  2. 不可约表示的类型和多重性:更高阶的不可约表示(如 l=2, l=3 等)需要更多计算
  3. 批量大小:更大的批量会线性增加计算量

对于标准等变线性层,计算复杂度大约为 O(N_in * N_out),其中 N_in 和 N_out 分别是输入和输出表示的总维度。

内存使用

等变线性层的内存使用主要包括:

  1. 权重参数:数量取决于输入和输出表示中相同类型不可约表示的多重性
  2. 中间激活:在前向和反向传播过程中需要存储的中间结果
  3. 梯度:训练过程中计算的梯度

性能优化技巧

以下是一些优化等变线性层性能的技巧:

1. 选择合适的不可约表示

# 不推荐:使用过多高阶不可约表示
irreps_inefficient = cue.Irreps("SO3", "10x0 + 10x1 + 10x2 + 10x3 + 10x4")

# 推荐:专注于低阶不可约表示,减少高阶表示的数量
irreps_efficient = cue.Irreps("SO3", "20x0 + 15x1 + 5x2")

2. 使用批处理来提高 GPU 利用率

# 处理单个样本
single_input = torch.randn(1, irreps_in.dim)
single_output = model(single_input)  # 低GPU利用率

# 批处理多个样本
batch_input = torch.randn(32, irreps_in.dim)
batch_output = model(batch_input)  # 更高的GPU利用率

3. 在 PyTorch 中使用混合精度训练

from torch.cuda.amp import autocast, GradScaler

# 创建梯度缩放器
scaler = GradScaler()

# 在训练循环中
optimizer.zero_grad()

# 使用混合精度
with autocast():
    output = model(input_data)
    loss = loss_function(output, target)

# 缩放梯度并优化
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

4. 在 JAX 中使用 XLA 编译器优化

# 使用JAX的XLA编译器优化计算
@jax.jit
def train_step(model, x, y, optimizer_state):
    def loss_fn(params):
        # 使用参数更新模型
        updated_model = replace_params(model, params)
        # 计算预测和损失
        pred = updated_model(x)
        loss = loss_function(pred, y)
        return loss
    
    # 计算梯度
    grad = jax.grad(loss_fn)(optimizer_state.params)
    # 更新优化器状态
    new_optimizer_state = optimizer.update(grad, optimizer_state)
    return new_optimizer_state

5. 使用 cuEquivariance 的 JIT 内核(如果可用)

import os

# 启用JIT内核
os.environ["CUEQUIVARIANCE_OPS_USE_JIT"] = "1"

# 创建等变模块
module = cuet.EquivariantTensorProduct(
    descriptor, 
    layout=cue.ir_mul,
    use_fallback=True  # 如果JIT内核不可用,回退到标准实现
)

6. 监控和分析性能

# 在PyTorch中使用性能分析器
from torch.profiler import profile, record_function, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with record_function("model_inference"):
        output = model(input_data)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

通过这些优化技巧,我们可以显著提高等变线性层的性能,使等变神经网络在实际应用中更加实用。

与传统线性层的性能对比

等变线性层通常比传统线性层需要更多的计算资源,因为它们需要保持等变性约束。然而,这种额外的计算成本通常可以通过以下优势来抵消:

  1. 更少的参数:等变约束减少了需要学习的参数数量
  2. 更好的泛化:等变性使模型能够更好地泛化到未见过的数据
  3. 更少的训练数据:由于内置的对称性,等变网络通常需要更少的训练数据

以下是一个简单的对比示例:

import torch
import torch.nn as nn
import cuequivariance as cue
import cuequivariance_torch as cuet
import time

# 创建一个传统的线性层
in_features = 25  # 与等变层的输入维度相同
out_features = 14  # 与等变层的输出维度相同
traditional_linear = nn.Linear(in_features, out_features)

# 创建一个等变线性层
irreps_in = cue.Irreps("SO3", "10x0 + 5x1")  # 总维度为25
irreps_out = cue.Irreps("SO3", "5x0 + 3x1")  # 总维度为14
equivariant_desc = cue.descriptors.linear(irreps_in, irreps_out)
equivariant_linear = cuet.EquivariantTensorProduct(
    equivariant_desc, 
    layout=cue.ir_mul,
    use_fallback=True
)
equivariant_weights = torch.randn(1, equivariant_desc.weight_numel)

# 创建输入数据
batch_size = 1000
x = torch.randn(batch_size, in_features)

# 测量传统线性层的性能
start_time = time.time()
traditional_output = traditional_linear(x)
traditional_time = time.time() - start_time

# 测量等变线性层的性能
start_time = time.time()
equivariant_output = equivariant_linear(equivariant_weights, x)
equivariant_time = time.time() - start_time

print(f"传统线性层时间: {traditional_time:.6f} 秒")
print(f"等变线性层时间: {equivariant_time:.6f} 秒")
print(f"速度比 (传统/等变): {traditional_time/equivariant_time:.2f}")

# 比较参数数量
traditional_params = sum(p.numel() for p in traditional_linear.parameters())
equivariant_params = equivariant_desc.weight_numel

print(f"传统线性层参数数量: {traditional_params}")
print(f"等变线性层参数数量: {equivariant_params}")
print(f"参数比 (传统/等变): {traditional_params/equivariant_params:.2f}")

这个对比展示了等变线性层与传统线性层在性能和参数数量方面的差异。虽然等变线性层可能在计算上更昂贵,但它们通常需要更少的参数,并且在处理具有对称性的数据时提供更好的性能。

NVIDIA cuEquivariance 详细教程:基本示例:张量积

6.1 等变张量积的基本概念

张量积是等变神经网络中的一个核心操作,它允许我们将两个等变特征组合起来,生成一个新的等变特征。在本节中,我们将详细介绍如何使用 cuEquivariance 创建和使用等变张量积。

张量积的数学基础

从数学角度看,两个表示 V 和 W 的张量积 V ⊗ W 是一个新的表示,其变换行为由原始表示的变换行为共同决定。对于群 G 的两个表示 ρ_V 和 ρ_W,它们的张量积表示 ρ_{V⊗W} 定义为:

ρ_{V⊗W}(g) = ρ_V(g) ⊗ ρ_W(g)

其中 g 是群 G 中的元素,⊗ 表示克罗内克积(Kronecker product)。

在等变神经网络中,我们通常关注的是 SO(3) 或 O(3) 群的不可约表示的张量积。这些张量积可以分解为不可约表示的直和,遵循特定的分解规则:

对于 SO(3) 群,两个不可约表示 D^l₁ 和 D^l₂ 的张量积分解为:

D^l₁ ⊗ D^l₂ = ⊕_{l=|l₁-l₂|}^{l₁+l₂} D^l

例如,两个 l=1 表示(向量)的张量积分解为:

D^1 ⊗ D^1 = D^0 ⊕ D^1 ⊕ D^2

这对应于物理中熟悉的向量乘法分解:两个向量的乘积可以分解为一个标量(点积,l=0)、一个向量(叉积,l=1)和一个对称无迹张量(l=2)。

等变张量积的作用

等变张量积在等变神经网络中有多种重要作用:

  1. 特征组合:将不同类型的特征组合成新的特征,同时保持等变性
  2. 非线性引入:通过组合特征引入非线性,而不破坏等变性
  3. 信息交互:允许不同通道和不同类型的特征之间交换信息
  4. 高阶特征生成:从低阶特征生成高阶特征,如从向量生成二阶张量

6.2 使用 cuEquivariance 创建等变张量积

cuEquivariance 提供了强大的工具来创建和使用等变张量积。让我们从一个简单的例子开始:

# 导入必要的库
import cuequivariance as cue
import torch
import cuequivariance_torch as cuet

# 定义两个输入表示
irreps_in1 = cue.Irreps("SO3", "1x0 + 1x1")  # 一个标量和一个向量
irreps_in2 = cue.Irreps("SO3", "1x1")        # 一个向量

# 创建等变张量积描述符
# 不指定输出表示,将自动计算所有可能的输出
tensor_product = cue.descriptors.tensor_product(irreps_in1, irreps_in2)

# 查看描述符的基本信息
print(f"输入表示1: {tensor_product.irreps_in[0]}")
print(f"输入表示2: {tensor_product.irreps_in[1]}")
print(f"输出表示: {tensor_product.irreps_out}")
print(f"权重参数数量: {tensor_product.weight_numel}")

在这个例子中:

  • 第一个输入表示包含一个标量(l=0)和一个向量(l=1)
  • 第二个输入表示包含一个向量(l=1)
  • 输出表示将自动计算,包含所有可能的张量积分解结果

根据 SO(3) 表示的张量积规则:

  • 0 ⊗ 1 = 1(标量与向量的乘积是向量)
  • 1 ⊗ 1 = 0 ⊕ 1 ⊕ 2(向量与向量的乘积分解为标量、向量和二阶张量)

因此,输出表示应该是 “1x0 + 2x1 + 1x2”。

指定输出表示

我们也可以明确指定输出表示,只保留我们感兴趣的部分:

# 指定输出表示,只保留标量和向量部分
irreps_out = cue.Irreps("SO3", "1x0 + 2x1")  # 只保留标量和向量,忽略二阶张量
tensor_product_specific = cue.descriptors.tensor_product(
    irreps_in1, irreps_in2, irreps_out
)

print(f"输入表示1: {tensor_product_specific.irreps_in[0]}")
print(f"输入表示2: {tensor_product_specific.irreps_in[1]}")
print(f"输出表示: {tensor_product_specific.irreps_out}")
print(f"权重参数数量: {tensor_product_specific.weight_numel}")

这种方法允许我们控制张量积的输出,只保留对特定应用有用的部分。

理解权重参数

等变张量积的权重参数控制着不同路径的贡献。每个路径对应于一种从输入到输出的映射方式,遵循等变性约束。

权重参数的数量取决于:

  1. 输入表示的结构
  2. 输出表示的结构
  3. 可能的等变映射路径数量

对于复杂的张量积,权重参数的结构可能很复杂。cuEquivariance 会自动处理这些细节,确保所有操作保持等变性。

6.3 在 PyTorch 中使用等变张量积

一旦我们创建了等变张量积描述符,就可以在 PyTorch 中使用它来构建等变神经网络。下面是一个完整的例子:

import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

class EquivariantNetwork(torch.nn.Module):
    """
    使用等变张量积的神经网络
    """
    def __init__(self):
        """
        初始化网络
        """
        super().__init__()
        
        # 定义表示
        self.irreps_in1 = cue.Irreps("SO3", "10x0 + 5x1")  # 第一个输入
        self.irreps_in2 = cue.Irreps("SO3", "3x0 + 2x1")   # 第二个输入
        self.irreps_hidden = cue.Irreps("SO3", "5x0 + 3x1 + 2x2")  # 隐藏层
        self.irreps_out = cue.Irreps("SO3", "1x0")  # 输出(标量)
        
        # 创建第一个等变张量积
        self.tp1_desc = cue.descriptors.tensor_product(
            self.irreps_in1, self.irreps_in2, self.irreps_hidden
        )
        self.tp1 = cuet.EquivariantTensorProduct(
            self.tp1_desc, 
            layout=cue.ir_mul,
            use_fallback=True
        )
        
        # 创建第二个等变张量积(映射到标量输出)
        self.tp2_desc = cue.descriptors.tensor_product(
            self.irreps_hidden, self.irreps_hidden, self.irreps_out
        )
        self.tp2 = cuet.EquivariantTensorProduct(
            self.tp2_desc, 
            layout=cue.ir_mul,
            use_fallback=True
        )
        
        # 初始化权重
        self.weights1 = torch.nn.Parameter(torch.randn(1, self.tp1_desc.weight_numel))
        self.weights2 = torch.nn.Parameter(torch.randn(1, self.tp2_desc.weight_numel))
    
    def forward(self, x1, x2):
        """
        前向传播
        
        参数:
            x1: 第一个输入张量,形状为[batch_size, irreps_in1.dim]
            x2: 第二个输入张量,形状为[batch_size, irreps_in2.dim]
            
        返回:
            输出张量,形状为[batch_size, irreps_out.dim]
        """
        # 第一个张量积
        h = self.tp1(self.weights1, x1, x2)
        
        # 非线性激活(只应用于标量部分)
        scalar_dim = sum(mul * ir.dim for mul, ir in self.irreps_hidden if ir.l == 0)
        if scalar_dim > 0:
            h_scalar = h[:, :scalar_dim]
            h_vector = h[:, scalar_dim:]
            h_scalar = torch.relu(h_scalar)
            h = torch.cat([h_scalar, h_vector], dim=1)
        
        # 第二个张量积(自我交互,产生标量输出)
        out = self.tp2(self.weights2, h, h)
        
        return out

# 创建网络实例
model = EquivariantNetwork()

# 创建一些随机输入数据
batch_size = 4
x1 = torch.randn(batch_size, model.irreps_in1.dim)
x2 = torch.randn(batch_size, model.irreps_in2.dim)

# 前向传播
y = model(x1, x2)

print(f"输入1形状: {x1.shape}")
print(f"输入2形状: {x2.shape}")
print(f"输出形状: {y.shape}")  # 应该是 [batch_size, 1]

在这个例子中,我们创建了一个使用两个等变张量积的神经网络:

  1. 第一个张量积将两个输入表示组合成一个隐藏表示
  2. 第二个张量积将隐藏表示与自身交互,生成一个标量输出

这种架构在处理成对数据(如分子间相互作用)时特别有用。

处理多个输入

等变张量积可以处理两个或更多输入。对于两个以上的输入,我们可以级联多个张量积操作:

# 处理三个输入的例子
def process_three_inputs(x1, x2, x3, model):
    """
    处理三个等变输入
    
    参数:
        x1, x2, x3: 三个输入张量
        model: 包含等变张量积的模型
        
    返回:
        处理后的输出
    """
    # 首先组合前两个输入
    intermediate = model.tp1(model.weights1, x1, x2)
    
    # 然后将结果与第三个输入组合
    output = model.tp2(model.weights2, intermediate, x3)
    
    return output

实现自注意力机制

等变张量积还可以用于实现等变自注意力机制,这在处理点云或图数据时特别有用:

class EquivariantSelfAttention(torch.nn.Module):
    """
    等变自注意力机制
    """
    def __init__(self, irreps_in, irreps_out):
        """
        初始化等变自注意力层
        
        参数:
            irreps_in: 输入表示
            irreps_out: 输出表示
        """
        super().__init__()
        
        # 定义查询、键、值的表示
        self.irreps_qk = cue.Irreps("SO3", "5x0")  # 查询和键使用标量
        self.irreps_v = irreps_in  # 值使用与输入相同的表示
        
        # 创建查询、键、值的线性映射
        self.query_desc = cue.descriptors.linear(irreps_in, self.irreps_qk)
        self.key_desc = cue.descriptors.linear(irreps_in, self.irreps_qk)
        self.value_desc = cue.descriptors.linear(irreps_in, self.irreps_v)
        
        self.query = cuet.EquivariantTensorProduct(
            self.query_desc, layout=cue.ir_mul, use_fallback=True
        )
        self.key = cuet.EquivariantTensorProduct(
            self.key_desc, layout=cue.ir_mul, use_fallback=True
        )
        self.value = cuet.EquivariantTensorProduct(
            self.value_desc, layout=cue.ir_mul, use_fallback=True
        )
        
        # 创建输出映射
        self.output_desc = cue.descriptors.linear(self.irreps_v, irreps_out)
        self.output = cuet.EquivariantTensorProduct(
            self.output_desc, layout=cue.ir_mul, use_fallback=True
        )
        
        # 初始化权重
        self.query_weights = torch.nn.Parameter(torch.randn(1, self.query_desc.weight_numel))
        self.key_weights = torch.nn.Parameter(torch.randn(1, self.key_desc.weight_numel))
        self.value_weights = torch.nn.Parameter(torch.randn(1, self.value_desc.weight_numel))
        self.output_weights = torch.nn.Parameter(torch.randn(1, self.output_desc.weight_numel))
    
    def forward(self, x):
        """
        前向传播
        
        参数:
            x: 输入张量,形状为[batch_size, num_points, irreps_in.dim]
            
        返回:
            输出张量,形状为[batch_size, num_points, irreps_out.dim]
        """
        batch_size, num_points, _ = x.shape
        
        # 计算查询、键、值
        q = self.query(self.query_weights, x.reshape(batch_size * num_points, -1))
        q = q.reshape(batch_size, num_points, -1)
        
        k = self.key(self.key_weights, x.reshape(batch_size * num_points, -1))
        k = k.reshape(batch_size, num_points, -1)
        
        v = self.value(self.value_weights, x.reshape(batch_size * num_points, -1))
        v = v.reshape(batch_size, num_points, -1)
        
        # 计算注意力分数(只使用标量部分确保等变性)
        scores = torch.bmm(q, k.transpose(1, 2)) / (self.irreps_qk.dim ** 0.5)
        
        # 应用softmax
        attention = torch.softmax(scores, dim=-1)
        
        # 计算加权和
        out = torch.bmm(attention, v)
        
        # 应用输出映射
        out = self.output(self.output_weights, out.reshape(batch_size * num_points, -1))
        out = out.reshape(batch_size, num_points, -1)
        
        return out

这个等变自注意力实现确保了在旋转输入数据时,注意力机制的输出也会相应地旋转,保持等变性。

6.4 在 JAX 中使用等变张量积

cuEquivariance 也提供了 JAX 支持,使我们可以在 JAX 生态系统中使用等变张量积。以下是一个在 JAX 中使用等变张量积的例子:

import jax
import jax.numpy as jnp
import cuequivariance as cue
import cuequivariance_jax as cuex

class EquivariantNetworkJax:
    """
    JAX版本的等变网络
    """
    def __init__(self, key):
        """
        初始化网络
        
        参数:
            key: JAX随机数生成器密钥
        """
        # 定义表示
        self.irreps_in1 = cue.Irreps("SO3", "10x0 + 5x1")  # 第一个输入
        self.irreps_in2 = cue.Irreps("SO3", "3x0 + 2x1")   # 第二个输入
        self.irreps_hidden = cue.Irreps("SO3", "5x0 + 3x1 + 2x2")  # 隐藏层
        self.irreps_out = cue.Irreps("SO3", "1x0")  # 输出(标量)
        
        # 创建等变张量积描述符
        self.tp1_desc = cue.descriptors.tensor_product(
            self.irreps_in1, self.irreps_in2, self.irreps_hidden
        )
        self.tp2_desc = cue.descriptors.tensor_product(
            self.irreps_hidden, self.irreps_hidden, self.irreps_out
        )
        
        # 初始化权重
        key1, key2 = jax.random.split(key)
        self.weights1 = jax.random.normal(key1, (self.tp1_desc.weight_numel,))
        self.weights2 = jax.random.normal(key2, (self.tp2_desc.weight_numel,))
    
    def __call__(self, x1, x2):
        """
        前向传播
        
        参数:
            x1: 第一个输入张量
            x2: 第二个输入张量
            
        返回:
            输出张量
        """
        # 第一个张量积
        h = cuex.equivariant_tensor_product(self.tp1_desc, self.weights1, x1, x2)
        
        # 非线性激活(只应用于标量部分)
        scalar_dim = sum(mul * ir.dim for mul, ir in self.irreps_hidden if ir.l == 0)
        if scalar_dim > 0:
            h_scalar = h[:, :scalar_dim]
            h_vector = h[:, scalar_dim:]
            h_scalar = jax.nn.relu(h_scalar)
            h = jnp.concatenate([h_scalar, h_vector], axis=1)
        
        # 第二个张量积
        out = cuex.equivariant_tensor_product(self.tp2_desc, self.weights2, h, h)
        
        return out

# 创建网络实例
key = jax.random.key(0)
model = EquivariantNetworkJax(key)

# 创建一些随机输入数据
batch_size = 4
input_key1, input_key2 = jax.random.split(jax.random.key(1))
x1 = jax.random.normal(input_key1, (batch_size, model.irreps_in1.dim))
x2 = jax.random.normal(input_key2, (batch_size, model.irreps_in2.dim))

# 前向传播
y = model(x1, x2)

print(f"输入1形状: {x1.shape}")
print(f"输入2形状: {x2.shape}")
print(f"输出形状: {y.shape}")  # 应该是 (batch_size, 1)

JAX 版本的实现与 PyTorch 版本在概念上是相似的,但使用了 JAX 的函数式编程风格和 cuex.equivariant_tensor_product 函数。

JAX 中的梯度和优化

JAX 提供了强大的自动微分功能,使我们可以轻松计算等变网络的梯度并进行优化:

import optax  # JAX优化库

def loss_fn(model, x1, x2, y_true):
    """
    计算损失函数
    
    参数:
        model: 模型
        x1, x2: 输入
        y_true: 真实标签
        
    返回:
        损失值
    """
    y_pred = model(x1, x2)
    return jnp.mean((y_pred - y_true) ** 2)

# 创建优化器
optimizer = optax.adam(learning_rate=1e-3)

# 初始化优化器状态
params = {'weights1': model.weights1, 'weights2': model.weights2}
opt_state = optimizer.init(params)

# 定义训练步骤
@jax.jit
def train_step(params, opt_state, x1, x2, y_true):
    """
    执行一步训练
    
    参数:
        params: 模型参数
        opt_state: 优化器状态
        x1, x2: 输入
        y_true: 真实标签
        
    返回:
        更新后的参数和优化器状态,以及损失值
    """
    def loss_for_params(p):
        # 创建一个使用给定参数的模型副本
        model_copy = model  # 在实际应用中,应该创建一个真正的副本
        model_copy.weights1 = p['weights1']
        model_copy.weights2 = p['weights2']
        return loss_fn(model_copy, x1, x2, y_true)
    
    # 计算损失和梯度
    loss, grads = jax.value_and_grad(loss_for_params)(params)
    
    # 更新参数
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    
    return new_params, new_opt_state, loss

# 模拟训练循环
y_true = jnp.zeros((batch_size, 1))  # 假设目标是零

for i in range(10):
    params, opt_state, loss = train_step(params, opt_state, x1, x2, y_true)
    print(f"步骤 {i+1}, 损失: {loss}")

# 更新模型参数
model.weights1 = params['weights1']
model.weights2 = params['weights2']

# 测试优化后的模型
y_optimized = model(x1, x2)
print(f"优化后的输出: {y_optimized}")

这个例子展示了如何在 JAX 中训练等变网络,利用 JAX 的自动微分和优化功能。

6.5 等变张量积的应用案例

等变张量积在多个领域有广泛的应用。以下是一些具体的应用案例:

分子性质预测

在分子性质预测中,等变张量积可以用于组合原子特征和键特征,同时保持对分子旋转和平移的等变性:

def molecular_property_prediction(atomic_features, bond_features, model):
    """
    预测分子性质
    
    参数:
        atomic_features: 原子特征,形状为[num_atoms, irreps_atom.dim]
        bond_features: 键特征,形状为[num_bonds, irreps_bond.dim]
        model: 包含等变张量积的模型
        
    返回:
        预测的分子性质(标量)
    """
    # 使用等变张量积组合原子和键特征
    combined_features = model.tp1(model.weights1, atomic_features, bond_features)
    
    # 应用非线性(只对标量部分)
    scalar_dim = sum(mul * ir.dim for mul, ir in model.irreps_hidden if ir.l == 0)
    combined_features_scalar = combined_features[:, :scalar_dim]
    combined_features_vector = combined_features[:, scalar_dim:]
    combined_features_scalar = torch.relu(combined_features_scalar)
    combined_features = torch.cat([combined_features_scalar, combined_features_vector], dim=1)
    
    # 聚合所有原子的特征(求和池化)
    global_features = combined_features.sum(dim=0, keepdim=True)
    
    # 预测最终的分子性质(标量)
    property_prediction = model.tp2(model.weights2, global_features, global_features)
    
    return property_prediction

点云处理

在点云处理中,等变张量积可以用于组合点特征和局部几何信息:

def point_cloud_processing(points, point_features, model):
    """
    处理点云数据
    
    参数:
        points: 点坐标,形状为[num_points, 3]
        point_features: 点特征,形状为[num_points, irreps_point.dim]
        model: 包含等变张量积的模型
        
    返回:
        处理后的点特征
    """
    # 计算每个点的局部几何描述符(如球谐函数)
    # 假设我们有一个函数来计算这些描述符
    geometric_descriptors = compute_geometric_descriptors(points)
    
    # 使用等变张量积组合点特征和几何描述符
    enhanced_features = model.tp1(model.weights1, point_features, geometric_descriptors)
    
    # 应用非线性(只对标量部分)
    # ...
    
    # 返回增强的点特征
    return enhanced_features

蛋白质结构预测

在蛋白质结构预测中,等变张量积可以用于组合氨基酸特征和局部结构信息:

def protein_structure_prediction(amino_acid_features, local_structure, model):
    """
    预测蛋白质结构
    
    参数:
        amino_acid_features: 氨基酸特征
        local_structure: 局部结构信息
        model: 包含等变张量积的模型
        
    返回:
        预测的蛋白质结构
    """
    # 使用等变张量积组合氨基酸特征和局部结构信息
    combined_features = model.tp1(model.weights1, amino_acid_features, local_structure)
    
    # 应用非线性(只对标量部分)
    # ...
    
    # 预测下一层的结构信息
    predicted_structure = model.tp2(model.weights2, combined_features, combined_features)
    
    return predicted_structure

这些应用案例展示了等变张量积在处理具有几何对称性的数据时的强大能力。通过保持等变性,这些模型能够更有效地学习和泛化,特别是在数据有限的情况下。

NVIDIA cuEquivariance 详细教程:实际应用案例

7.1 球谐函数与旋转等变性

球谐函数是描述三维空间中旋转等变性的强大工具,在等变神经网络中扮演着重要角色。在本节中,我们将详细介绍如何使用 cuEquivariance 中的球谐函数功能,并展示其在实际应用中的价值。

7.1.1 球谐函数的基本概念

球谐函数 Y_l^m(θ,φ) 是定义在球面上的特殊函数,它们构成了球面上平方可积函数的一组完备正交基。在等变神经网络中,球谐函数与 SO(3) 群的不可约表示密切相关:对于给定的角动量量子数 l,2l+1 个球谐函数 Y_l^m(m 从 -l 到 l)形成了 SO(3) 第 l 个不可约表示的一个基。

球谐函数的关键特性是它们在旋转下的变换行为是已知的,这使得它们成为构建等变特征的理想工具。当我们旋转输入向量时,球谐函数的输出会按照特定的方式变换,保持等变性。

7.1.2 使用 cuEquivariance 计算球谐函数

cuEquivariance 提供了高效计算球谐函数的功能。以下是一个基本示例:

# 导入必要的库
import cuequivariance as cue
import cuequivariance_torch as cuet
import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# 创建球谐函数描述符
# 计算 l=0,1,2,3 的球谐函数
sh_descriptor = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3])

# 创建 PyTorch 模块
sh_module = cuet.EquivariantTensorProduct(
    sh_descriptor, 
    layout=cue.ir_mul,
    use_fallback=True
)

# 创建一个单位向量网格
theta = np.linspace(0, np.pi, 100)
phi = np.linspace(0, 2*np.pi, 100)
theta_grid, phi_grid = np.meshgrid(theta, phi)

# 将球坐标转换为笛卡尔坐标
x = np.sin(theta_grid) * np.cos(phi_grid)
y = np.sin(theta_grid) * np.sin(phi_grid)
z = np.cos(theta_grid)

# 将坐标转换为张量
points = torch.tensor(np.stack([x.flatten(), y.flatten(), z.flatten()], axis=1), dtype=torch.float32)

# 创建权重(对于球谐函数,权重通常是单位权重)
weights = torch.ones(1, sh_descriptor.weight_numel)

# 计算球谐函数值
sh_values = sh_module(weights, points)

# 提取 l=2, m=0 的球谐函数值(作为示例)
# 注意:索引取决于球谐函数的具体排列方式
l2m0_index = 1 + 3 + 0  # l=0 有1个,l=1 有3个,然后是 l=2, m=0
l2m0_values = sh_values[:, l2m0_index].reshape(phi.size, theta.size).numpy()

# 可视化
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# 将球谐函数值映射到颜色
norm = plt.Normalize(l2m0_values.min(), l2m0_values.max())
colors = plt.cm.viridis(norm(l2m0_values))

# 绘制球面
surf = ax.plot_surface(x, y, z, facecolors=colors, alpha=0.7)

# 添加颜色条
cbar = fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=plt.cm.viridis), ax=ax)
cbar.set_label('Y_2^0 值')

# 设置标题和标签
ax.set_title('Y_2^0 球谐函数')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

plt.tight_layout()
plt.savefig('spherical_harmonic_l2m0.png')
plt.close()

print("球谐函数可视化已保存为 'spherical_harmonic_l2m0.png'")

这个例子展示了如何使用 cuEquivariance 计算球谐函数值,并将结果可视化在球面上。我们选择了 l=2, m=0 的球谐函数作为示例,它具有典型的"哑铃"形状。

7.1.3 球谐函数在特征提取中的应用

球谐函数在等变神经网络中的一个重要应用是将三维向量(如原子坐标或点云中的点)映射到等变特征。以下是一个在分子表示中使用球谐函数的例子:

def create_molecular_features(atomic_positions, atomic_types, max_l=3):
    """
    使用球谐函数创建分子的等变特征表示
    
    参数:
        atomic_positions: 原子坐标,形状为 [num_atoms, 3]
        atomic_types: 原子类型,形状为 [num_atoms]
        max_l: 最大角动量量子数
        
    返回:
        分子的等变特征
    """
    # 导入必要的库
    import cuequivariance as cue
    import cuequivariance_torch as cuet
    import torch
    
    # 创建球谐函数描述符
    l_values = list(range(max_l + 1))  # 包括 0 到 max_l
    sh_descriptor = cue.descriptors.spherical_harmonics(cue.SO3(1), l_values)
    
    # 创建 PyTorch 模块
    sh_module = cuet.EquivariantTensorProduct(
        sh_descriptor, 
        layout=cue.ir_mul,
        use_fallback=True
    )
    
    # 创建权重
    weights = torch.ones(1, sh_descriptor.weight_numel)
    
    # 计算每对原子之间的相对位置
    num_atoms = atomic_positions.shape[0]
    features_list = []
    
    for i in range(num_atoms):
        # 计算从当前原子到所有其他原子的相对位置
        rel_pos = atomic_positions - atomic_positions[i:i+1]  # [num_atoms, 3]
        
        # 计算距离
        distances = torch.norm(rel_pos, dim=1, keepdim=True)  # [num_atoms, 1]
        
        # 忽略自身(距离为0)
        mask = distances > 1e-10
        valid_rel_pos = rel_pos[mask.squeeze()]  # [num_valid, 3]
        valid_distances = distances[mask]  # [num_valid, 1]
        valid_types = atomic_types[mask.squeeze()]  # [num_valid]
        
        # 归一化相对位置为单位向量
        unit_vectors = valid_rel_pos / valid_distances
        
        # 计算球谐函数值
        sh_values = sh_module(weights, unit_vectors)  # [num_valid, sum(2*l+1)]
        
        # 使用距离作为径向函数(简单示例,实际应用中可能需要更复杂的径向函数)
        radial_function = torch.exp(-valid_distances)  # [num_valid, 1]
        
        # 将球谐函数值与径向函数和原子类型结合
        # 这里我们简单地将原子类型转换为独热编码
        num_atom_types = len(torch.unique(atomic_types))
        type_encoding = torch.zeros(valid_types.shape[0], num_atom_types)
        type_encoding.scatter_(1, valid_types.unsqueeze(1), 1)
        
        # 组合特征:径向函数 * 球谐函数值 * 原子类型编码
        # 对于每种原子类型,我们得到一组加权的球谐函数值
        combined_features = []
        for j in range(num_atom_types):
            type_mask = type_encoding[:, j:j+1]
            type_features = sh_values * radial_function * type_mask
            # 对同一类型的原子求和
            summed_features = type_features.sum(dim=0, keepdim=True)
            combined_features.append(summed_features)
        
        # 将所有特征连接起来
        atom_features = torch.cat(combined_features, dim=1)
        features_list.append(atom_features)
    
    # 将所有原子的特征堆叠起来
    molecular_features = torch.cat(features_list, dim=0)
    
    return molecular_features

这个函数展示了如何使用球谐函数创建分子的等变特征表示。对于每个原子,我们计算它与所有其他原子的相对位置,然后使用球谐函数将这些相对位置映射到等变特征。我们还结合了径向函数和原子类型信息,创建了一个全面的分子表示。

7.1.4 验证旋转等变性

等变神经网络的一个关键特性是它们在输入旋转时的行为。我们可以通过以下实验来验证球谐函数的旋转等变性:

def verify_rotational_equivariance():
    """
    验证球谐函数的旋转等变性
    """
    # 导入必要的库
    import cuequivariance as cue
    import cuequivariance_torch as cuet
    import torch
    import numpy as np
    from scipy.spatial.transform import Rotation
    
    # 创建球谐函数描述符
    sh_descriptor = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2])
    
    # 创建 PyTorch 模块
    sh_module = cuet.EquivariantTensorProduct(
        sh_descriptor, 
        layout=cue.ir_mul,
        use_fallback=True
    )
    
    # 创建权重
    weights = torch.ones(1, sh_descriptor.weight_numel)
    
    # 创建一些随机单位向量
    num_vectors = 10
    random_vectors = torch.randn(num_vectors, 3)
    unit_vectors = random_vectors / torch.norm(random_vectors, dim=1, keepdim=True)
    
    # 计算原始球谐函数值
    original_sh_values = sh_module(weights, unit_vectors)
    
    # 创建一个随机旋转
    rotation = Rotation.random()
    rotation_matrix = torch.tensor(rotation.as_matrix(), dtype=torch.float32)
    
    # 旋转输入向量
    rotated_vectors = torch.matmul(unit_vectors, rotation_matrix.T)
    
    # 计算旋转后的球谐函数值
    rotated_sh_values = sh_module(weights, rotated_vectors)
    
    # 计算球谐函数值的理论变换
    # 注意:这需要使用 Wigner D-矩阵,这里我们使用 cuEquivariance 的旋转功能
    
    # 创建旋转描述符
    irreps_sh = cue.Irreps("SO3", "1x0 + 1x1 + 1x2")  # 对应于我们计算的球谐函数
    rotation_desc = cue.descriptors.yxy_rotation(irreps_sh)
    
    # 创建旋转模块
    rotation_module = cuet.EquivariantTensorProduct(
        rotation_desc, 
        layout=cue.ir_mul,
        use_fallback=True
    )
    
    # 从旋转矩阵创建 YXY 欧拉角
    # 注意:这是一个简化,实际应用中应使用适当的转换函数
    r = Rotation.from_matrix(rotation_matrix.numpy())
    euler_angles = r.as_euler('YXY')
    rotation_params = torch.tensor([euler_angles], dtype=torch.float32)
    
    # 应用旋转到原始球谐函数值
    transformed_sh_values = rotation_module(rotation_params, original_sh_values)
    
    # 计算误差
    error = torch.norm(rotated_sh_values - transformed_sh_values) / torch.norm(rotated_sh_values)
    
    print(f"相对误差: {error.item():.6f}")
    print("如果误差接近零,则验证了旋转等变性")
    
    return error.item()

这个函数通过以下步骤验证球谐函数的旋转等变性:

  1. 计算一组随机单位向量的球谐函数值
  2. 将这些向量旋转,然后计算旋转后向量的球谐函数值
  3. 使用 Wigner D-矩阵(通过 cuEquivariance 的旋转功能实现)将原始球谐函数值变换
  4. 比较直接计算的旋转后球谐函数值与理论变换值之间的误差

如果误差接近零,则验证了球谐函数的旋转等变性。

7.2 构建等变神经网络模型

在本节中,我们将展示如何使用 cuEquivariance 构建完整的等变神经网络模型,用于解决实际问题。我们将以分子性质预测为例,这是等变神经网络的一个重要应用领域。

7.2.1 分子性质预测模型

分子性质预测是药物发现和材料设计中的关键任务。由于分子的性质不应依赖于它们在空间中的方向,等变神经网络是这类问题的理想选择。以下是一个用于预测分子性质的等变神经网络模型:

import torch
import torch.nn as nn
import cuequivariance as cue
import cuequivariance_torch as cuet

class EquivariantMolecularModel(nn.Module):
    """
    用于分子性质预测的等变神经网络模型
    """
    def __init__(self, num_atom_types, max_l=2, hidden_channels=64):
        """
        初始化模型
        
        参数:
            num_atom_types: 原子类型的数量
            max_l: 最大角动量量子数
            hidden_channels: 隐藏层通道数
        """
        super().__init__()
        
        # 定义表示
        # 原子特征:标量特征(原子类型嵌入)
        self.irreps_atom = cue.Irreps("SO3", f"{num_atom_types}x0")
        
        # 位置编码:使用球谐函数,包括 l=0,1,...,max_l
        l_values = list(range(max_l + 1))
        sh_dims = [(2*l+1) for l in l_values]
        sh_irreps_str = " + ".join([f"1x{l}" for l in l_values])
        self.irreps_sh = cue.Irreps("SO3", sh_irreps_str)
        
        # 隐藏层表示:包括标量和向量特征
        self.irreps_hidden = cue.Irreps("SO3", f"{hidden_channels}x0 + {hidden_channels//2}x1 + {hidden_channels//4}x2")
        
        # 输出表示:标量(分子性质)
        self.irreps_out = cue.Irreps("SO3", "1x0")
        
        # 创建球谐函数描述符
        self.sh_desc = cue.descriptors.spherical_harmonics(cue.SO3(1), l_values)
        self.sh = cuet.EquivariantTensorProduct(
            self.sh_desc, 
            layout=cue.ir_mul,
            use_fallback=True
        )
        self.sh_weights = nn.Parameter(torch.ones(1, self.sh_desc.weight_numel))
        
        # 创建原子嵌入层(将原子类型映射到标量特征)
        self.atom_embedding = nn.Embedding(num_atom_types, num_atom_types)
        
        # 创建消息传递层
        # 第一层:组合原子特征和相对位置编码
        self.message1_desc = cue.descriptors.tensor_product(
            self.irreps_atom, self.irreps_sh, self.irreps_hidden
        )
        self.message1 = cuet.EquivariantTensorProduct(
            self.message1_desc, 
            layout=cue.ir_mul,
            use_fallback=True
        )
        self.message1_weights = nn.Parameter(torch.randn(1, self.message1_desc.weight_numel))
        
        # 第二层:更新原子特征
        self.update1_desc = cue.descriptors.linear(
            self.irreps_hidden, self.irreps_hidden
        )
        self.update1 = cuet.EquivariantTensorProduct(
            self.update1_desc, 
            layout=cue.ir_mul,
            use_fallback=True
        )
        self.update1_weights = nn.Parameter(torch.randn(1, self.update1_desc.weight_numel))
        
        # 第三层:再次组合原子特征和相对位置编码
        self.message2_desc = cue.descriptors.tensor_product(
            self.irreps_hidden, self.irreps_sh, self.irreps_hidden
        )
        self.message2 = cuet.EquivariantTensorProduct(
            self.message2_desc, 
            layout=cue.ir_mul,
            use_fallback=True
        )
        self.message2_weights = nn.Parameter(torch.randn(1, self.message2_desc.weight_numel))
        
        # 第四层:最终更新原子特征
        self.update2_desc = cue.descriptors.linear(
            self.irreps_hidden, self.irreps_hidden
        )
        self.update2 = cuet.EquivariantTensorProduct(
            self.update2_desc, 
            layout=cue.ir_mul,
            use_fallback=True
        )
        self.update2_weights = nn.Parameter(torch.randn(1, self.update2_desc.weight_numel))
        
        # 输出层:将原子特征聚合为分子特征,然后预测性质
        self.output_desc = cue.descriptors.linear(
            self.irreps_hidden, self.irreps_out
        )
        self.output = cuet.EquivariantTensorProduct(
            self.output_desc, 
            layout=cue.ir_mul,
            use_fallback=True
        )
        self.output_weights = nn.Parameter(torch.randn(1, self.output_desc.weight_numel))
    
    def forward(self, atomic_numbers, positions, edge_index):
        """
        前向传播
        
        参数:
            atomic_numbers: 原子序数,形状为 [num_atoms]
            positions: 原子坐标,形状为 [num_atoms, 3]
            edge_index: 边索引,形状为 [2, num_edges]
            
        返回:
            预测的分子性质(标量)
        """
        # 获取原子数量
        num_atoms = atomic_numbers.shape[0]
        
        # 将原子序数转换为原子类型嵌入
        atom_features = self.atom_embedding(atomic_numbers)  # [num_atoms, num_atom_types]
        
        # 计算边的相对位置
        src, dst = edge_index
        rel_pos = positions[dst] - positions[src]  # [num_edges, 3]
        
        # 计算边的距离
        distances = torch.norm(rel_pos, dim=1, keepdim=True)  # [num_edges, 1]
        
        # 归一化相对位置为单位向量
        unit_vectors = rel_pos / distances  # [num_edges, 3]
        
        # 计算球谐函数值(位置编码)
        sh_values = self.sh(self.sh_weights, unit_vectors)  # [num_edges, irreps_sh.dim]
        
        # 使用径向函数调制球谐函数值
        # 这里使用简单的高斯径向函数
        radial_function = torch.exp(-(distances - 1.5)**2 / 0.5)  # [num_edges, 1]
        sh_values = sh_values * radial_function  # [num_edges, irreps_sh.dim]
        
        # 第一层消息传递
        # 对于每条边,组合源原子特征和位置编码
        edge_messages = self.message1(
            self.message1_weights, 
            atom_features[src], 
            sh_values
        )  # [num_edges, irreps_hidden.dim]
        
        # 聚合消息到目标原子
        atom_messages = torch.zeros(num_atoms, self.irreps_hidden.dim, device=edge_index.device)
        atom_messages.index_add_(0, dst, edge_messages)
        
        # 更新原子特征
        atom_features_hidden = self.update1(
            self.update1_weights, 
            atom_messages
        )  # [num_atoms, irreps_hidden.dim]
        
        # 应用非线性(只对标量部分)
        scalar_dim = sum(mul * ir.dim for mul, ir in self.irreps_hidden if ir.l == 0)
        if scalar_dim > 0:
            atom_features_scalar = atom_features_hidden[:, :scalar_dim]
            atom_features_vector = atom_features_hidden[:, scalar_dim:]
            atom_features_scalar = torch.relu(atom_features_scalar)
            atom_features_hidden = torch.cat([atom_features_scalar, atom_features_vector], dim=1)
        
        # 第二层消息传递
        edge_messages2 = self.message2(
            self.message2_weights, 
            atom_features_hidden[src], 
            sh_values
        )  # [num_edges, irreps_hidden.dim]
        
        # 聚合消息到目标原子
        atom_messages2 = torch.zeros(num_atoms, self.irreps_hidden.dim, device=edge_index.device)
        atom_messages2.index_add_(0, dst, edge_messages2)
        
        # 最终更新原子特征
        atom_features_final = self.update2(
            self.update2_weights, 
            atom_messages2
        )  # [num_atoms, irreps_hidden.dim]
        
        # 应用非线性(只对标量部分)
        if scalar_dim > 0:
            atom_features_scalar = atom_features_final[:, :scalar_dim]
            atom_features_vector = atom_features_final[:, scalar_dim:]
            atom_features_scalar = torch.relu(atom_features_scalar)
            atom_features_final = torch.cat([atom_features_scalar, atom_features_vector], dim=1)
        
        # 聚合所有原子的特征(求和池化)
        molecular_features = atom_features_final.sum(dim=0, keepdim=True)  # [1, irreps_hidden.dim]
        
        # 预测分子性质
        prediction = self.output(
            self.output_weights, 
            molecular_features
        )  # [1, 1]
        
        return prediction.squeeze()

这个模型实现了一个用于分子性质预测的等变消息传递神经网络。它包括以下关键组件:

  1. 原子表示:使用嵌入层将原子类型转换为标量特征
  2. 位置编码:使用球谐函数将原子间相对位置编码为等变特征
  3. 消息传递:通过等变张量积组合原子特征和位置编码,然后在分子图上传递消息
  4. 特征更新:使用等变线性层更新原子特征
  5. 全局池化:将所有原子特征聚合为分子特征
  6. 输出预测:使用等变线性层预测最终的分子性质(标量)

7.2.2 训练和评估模型

以下是如何训练和评估上述分子性质预测模型的示例:

def train_molecular_model(model, dataset, num_epochs=100, lr=0.001):
    """
    训练分子性质预测模型
    
    参数:
        model: 等变分子模型
        dataset: 分子数据集,每个元素包含 (atomic_numbers, positions, edge_index, target)
        num_epochs: 训练轮数
        lr: 学习率
        
    返回:
        训练好的模型
    """
    # 设置优化器
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    # 设置损失函数
    loss_fn = torch.nn.MSELoss()
    
    # 训练循环
    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0
        
        # 遍历数据集
        for atomic_numbers, positions, edge_index, target in dataset:
            # 前向传播
            prediction = model(atomic_numbers, positions, edge_index)
            
            # 计算损失
            loss = loss_fn(prediction, target)
            
            # 反向传播和优化
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # 累计损失
            total_loss += loss.item()
            num_batches += 1
        
        # 计算平均损失
        avg_loss = total_loss / num_batches
        
        # 打印进度
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.6f}")
    
    return model

def evaluate_molecular_model(model, test_dataset):
    """
    评估分子性质预测模型
    
    参数:
        model: 训练好的等变分子模型
        test_dataset: 测试数据集
        
    返回:
        平均绝对误差和均方根误差
    """
    model.eval()
    
    # 初始化指标
    mae_sum = 0.0
    mse_sum = 0.0
    num_samples = 0
    
    # 禁用梯度计算
    with torch.no_grad():
        # 遍历测试数据集
        for atomic_numbers, positions, edge_index, target in test_dataset:
            # 前向传播
            prediction = model(atomic_numbers, positions, edge_index)
            
            # 计算误差
            mae = torch.abs(prediction - target).item()
            mse = ((prediction - target) ** 2).item()
            
            # 累计误差
            mae_sum += mae
            mse_sum += mse
            num_samples += 1
    
    # 计算平均误差
    mae_avg = mae_sum / num_samples
    rmse_avg = (mse_sum / num_samples) ** 0.5
    
    print(f"评估结果 - MAE: {mae_avg:.6f}, RMSE: {rmse_avg:.6f}")
    
    return mae_avg, rmse_avg

7.2.3 验证旋转不变性

等变神经网络的一个关键优势是它们的预测在输入旋转时保持不变。我们可以通过以下实验来验证这一点:

def verify_rotational_invariance(model, molecule_data):
    """
    验证分子性质预测模型的旋转不变性
    
    参数:
        model: 等变分子模型
        molecule_data: 分子数据,包含 (atomic_numbers, positions, edge_index, target)
        
    返回:
        原始预测和旋转后预测之间的相对误差
    """
    # 导入必要的库
    import torch
    from scipy.spatial.transform import Rotation
    
    # 解包分子数据
    atomic_numbers, positions, edge_index, target = molecule_data
    
    # 设置模型为评估模式
    model.eval()
    
    # 禁用梯度计算
    with torch.no_grad():
        # 计算原始预测
        original_prediction = model(atomic_numbers, positions, edge_index)
        
        # 创建一个随机旋转
        rotation = Rotation.random()
        rotation_matrix = torch.tensor(rotation.as_matrix(), dtype=torch.float32)
        
        # 旋转原子坐标
        rotated_positions = torch.matmul(positions, rotation_matrix.T)
        
        # 计算旋转后的预测
        rotated_prediction = model(atomic_numbers, rotated_positions, edge_index)
        
        # 计算相对误差
        error = torch.abs(original_prediction - rotated_prediction) / (torch.abs(original_prediction) + 1e-10)
        
        print(f"原始预测: {original_prediction.item():.6f}")
        print(f"旋转后预测: {rotated_prediction.item():.6f}")
        print(f"相对误差: {error.item():.6f}")
        
        return error.item()

如果模型具有旋转不变性,那么原始预测和旋转后预测之间的误差应该接近零。这验证了等变神经网络在处理具有旋转对称性的数据时的优势。

7.3 分子动力学中的应用

分子动力学是等变神经网络的一个重要应用领域。在本节中,我们将展示如何使用 cuEquivariance 构建用于分子动力学模拟的等变势能模型。

7.3.1 等变势能模型

在分子动力学中,势能模型用于预测分子构型的能量和力。由于能量是标量(旋转不变的),而力是向量(旋转等变的),等变神经网络是构建物理一致的势能模型的理想选择。

以下是一个使用 cuEquivariance 实现的等变势能模型:

import torch
import torch.nn as nn
import cuequivariance as cue
import cuequivariance_torch as cuet

class EquivariantPotentialModel(nn.Module):
    """
    用于分子动力学的等变势能模型
    """
    def __init__(self, num_atom_types, max_l=2, hidden_channels=64):
        """
        初始化模型
        
        参数:
            num_atom_types: 原子类型的数量
            max_l: 最大角动量量子数
            hidden_channels: 隐藏层通道数
        """
        super().__init__()
        
        # 定义表示
        # 原子特征:标量特征(原子类型嵌入)
        self.irreps_atom = cue.Irreps("SO3", f"{num_atom_types}x0")
        
        # 位置编码:使用球谐函数,包括 l=0,1,...,max_l
        l_values = list(range(max_l + 1))
        sh_dims = [(2*l+1) for l in l_values]
        sh_irreps_str = " + ".join([f"1x{l}" for l in l_values])
        self.irreps_sh = cue.Irreps("SO3", sh_irreps_str)
        
        # 隐藏层表示:包括标量和向量特征
        self.irreps_hidden = cue.Irreps("SO3", f"{hidden_channels}x0 + {hidden_channels//2}x1 + {hidden_channels//4}x2")
        
        # 输出表示:标量(能量)
        self.irreps_out = cue.Irreps("SO3", "1x0")
        
        # 创建球谐函数描述符
        self.sh_desc = cue.descriptors.spherical_harmonics(cue.SO3(1), l_values)
        self.sh = cuet.EquivariantTensorProduct(
            self.sh_desc, 
            layout=cue.ir_mul,
            use_fallback=True
        )
        self.sh_weights = nn.Parameter(torch.ones(1, self.sh_desc.weight_numel))
        
        # 创建原子嵌入层(将原子类型映射到标量特征)
        self.atom_embedding = nn.Embedding(num_atom_types, num_atom_types)
        
        # 创建消息传递层
        # 第一层:组合原子特征和相对位置编码
        self.message1_desc = cue.descriptors.tensor_product(
            self.irreps_atom, self.irreps_sh, self.irreps_hidden
        )
        self.message1 = cuet.EquivariantTensorProduct(
            self.message1_desc, 
            layout=cue.ir_mul,
            use_fallback=True
        )
        self.message1_weights = nn.Parameter(torch.randn(1, self.message1_desc.weight_numel))
        
        # 第二层:更新原子特征
        self.update1_desc = cue.descriptors.linear(
            self.irreps_hidden, self.irreps_hidden
        )
        self.update1 = cuet.EquivariantTensorProduct(
            self.update1_desc, 
            layout=cue.ir_mul,
            use_fallback=True
        )
        self.update1_weights = nn.Parameter(torch.randn(1, self.update1_desc.weight_numel))
        
        # 第三层:再次组合原子特征和相对位置编码
        self.message2_desc = cue.descriptors.tensor_product(
            self.irreps_hidden, self.irreps_sh, self.irreps_hidden
        )
        self.message2 = cuet.EquivariantTensorProduct(
            self.message2_desc, 
            layout=cue.ir_mul,
            use_fallback=True
        )
        self.message2_weights = nn.Parameter(torch.randn(1, self.message2_desc.weight_numel))
        
        # 第四层:最终更新原子特征
        self.update2_desc = cue.descriptors.linear(
            self.irreps_hidden, self.irreps_hidden
        )
        self.update2 = cuet.EquivariantTensorProduct(
            self.update2_desc, 
            layout=cue.ir_mul,
            use_fallback=True
        )
        self.update2_weights = nn.Parameter(torch.randn(1, self.update2_desc.weight_numel))
        
        # 输出层:将原子特征聚合为分子特征,然后预测能量
        self.output_desc = cue.descriptors.linear(
            self.irreps_hidden, self.irreps_out
        )
        self.output = cuet.EquivariantTensorProduct(
            self.output_desc, 
            layout=cue.ir_mul,
            use_fallback=True
        )
        self.output_weights = nn.Parameter(torch.randn(1, self.output_desc.weight_numel))
    
    def forward(self, atomic_numbers, positions, edge_index, compute_forces=False):
        """
        前向传播
        
        参数:
            atomic_numbers: 原子序数,形状为 [num_atoms]
            positions: 原子坐标,形状为 [num_atoms, 3]
            edge_index: 边索引,形状为 [2, num_edges]
            compute_forces: 是否计算力
            
        返回:
            能量(标量)和力(如果 compute_forces=True)
        """
        # 如果需要计算力,需要跟踪位置的梯度
        if compute_forces:
            positions.requires_grad_(True)
        
        # 获取原子数量
        num_atoms = atomic_numbers.shape[0]
        
        # 将原子序数转换为原子类型嵌入
        atom_features = self.atom_embedding(atomic_numbers)  # [num_atoms, num_atom_types]
        
        # 计算边的相对位置
        src, dst = edge_index
        rel_pos = positions[dst] - positions[src]  # [num_edges, 3]
        
        # 计算边的距离
        distances = torch.norm(rel_pos, dim=1, keepdim=True)  # [num_edges, 1]
        
        # 归一化相对位置为单位向量
        unit_vectors = rel_pos / distances  # [num_edges, 3]
        
        # 计算球谐函数值(位置编码)
        sh_values = self.sh(self.sh_weights, unit_vectors)  # [num_edges, irreps_sh.dim]
        
        # 使用径向函数调制球谐函数值
        # 这里使用简单的高斯径向函数
        radial_function = torch.exp(-(distances - 1.5)**2 / 0.5)  # [num_edges, 1]
        sh_values = sh_values * radial_function  # [num_edges, irreps_sh.dim]
        
        # 第一层消息传递
        # 对于每条边,组合源原子特征和位置编码
        edge_messages = self.message1(
            self.message1_weights, 
            atom_features[src], 
            sh_values
        )  # [num_edges, irreps_hidden.dim]
        
        # 聚合消息到目标原子
        atom_messages = torch.zeros(num_atoms, self.irreps_hidden.dim, device=edge_index.device)
        atom_messages.index_add_(0, dst, edge_messages)
        
        # 更新原子特征
        atom_features_hidden = self.update1(
            self.update1_weights, 
            atom_messages
        )  # [num_atoms, irreps_hidden.dim]
        
        # 应用非线性(只对标量部分)
        scalar_dim = sum(mul * ir.dim for mul, ir in self.irreps_hidden if ir.l == 0)
        if scalar_dim > 0:
            atom_features_scalar = atom_features_hidden[:, :scalar_dim]
            atom_features_vector = atom_features_hidden[:, scalar_dim:]
            atom_features_scalar = torch.relu(atom_features_scalar)
            atom_features_hidden = torch.cat([atom_features_scalar, atom_features_vector], dim=1)
        
        # 第二层消息传递
        edge_messages2 = self.message2(
            self.message2_weights, 
            atom_features_hidden[src], 
            sh_values
        )  # [num_edges, irreps_hidden.dim]
        
        # 聚合消息到目标原子
        atom_messages2 = torch.zeros(num_atoms, self.irreps_hidden.dim, device=edge_index.device)
        atom_messages2.index_add_(0, dst, edge_messages2)
        
        # 最终更新原子特征
        atom_features_final = self.update2(
            self.update2_weights, 
            atom_messages2
        )  # [num_atoms, irreps_hidden.dim]
        
        # 应用非线性(只对标量部分)
        if scalar_dim > 0:
            atom_features_scalar = atom_features_final[:, :scalar_dim]
            atom_features_vector = atom_features_final[:, scalar_dim:]
            atom_features_scalar = torch.relu(atom_features_scalar)
            atom_features_final = torch.cat([atom_features_scalar, atom_features_vector], dim=1)
        
        # 预测每个原子的能量贡献
        atom_energies = self.output(
            self.output_weights, 
            atom_features_final
        )  # [num_atoms, 1]
        
        # 计算总能量(所有原子能量的总和)
        total_energy = atom_energies.sum()
        
        # 如果需要,计算力
        if compute_forces:
            # 力是能量对位置的负梯度
            forces = -torch.autograd.grad(
                total_energy, 
                positions, 
                create_graph=True, 
                retain_graph=True
            )[0]
            return total_energy, forces
        else:
            return total_energy

这个模型与前面的分子性质预测模型类似,但有一个关键区别:它可以计算力。力是能量对原子位置的负梯度,通过自动微分计算。由于能量是标量(旋转不变的),力自然是向量(旋转等变的),这保证了模型的物理一致性。

7.3.2 分子动力学模拟

使用上述等变势能模型,我们可以进行分子动力学模拟。以下是一个简单的分子动力学模拟示例:

def molecular_dynamics_simulation(model, atomic_numbers, initial_positions, edge_index, num_steps=1000, dt=0.001, temperature=300):
    """
    使用等变势能模型进行分子动力学模拟
    
    参数:
        model: 等变势能模型
        atomic_numbers: 原子序数
        initial_positions: 初始原子坐标
        edge_index: 边索引
        num_steps: 模拟步数
        dt: 时间步长
        temperature: 模拟温度(K)
        
    返回:
        轨迹(原子坐标随时间的变化)
    """
    # 导入必要的库
    import torch
    import numpy as np
    
    # 设置模型为评估模式
    model.eval()
    
    # 初始化轨迹
    trajectory = [initial_positions.clone().detach()]
    
    # 初始化速度(从麦克斯韦-玻尔兹曼分布采样)
    # 转换温度到适当的单位
    kB = 8.617333262e-5  # 玻尔兹曼常数,单位:eV/K
    mass = torch.ones(atomic_numbers.size(0), 1) * 1.0  # 假设所有原子质量相同,单位:原子质量单位
    velocity_scale = torch.sqrt(kB * temperature / mass)
    velocities = torch.randn_like(initial_positions) * velocity_scale
    
    # 确保总动量为零
    velocities = velocities - velocities.mean(dim=0, keepdim=True)
    
    # 当前位置
    positions = initial_positions.clone()
    
    # 模拟循环
    for step in range(num_steps):
        # 计算能量和力
        energy, forces = model(atomic_numbers, positions, edge_index, compute_forces=True)
        
        # 更新速度(Velocity Verlet 算法的第一步)
        velocities = velocities + 0.5 * dt * forces / mass
        
        # 更新位置
        positions = positions + dt * velocities
        
        # 计算新的力
        _, new_forces = model(atomic_numbers, positions, edge_index, compute_forces=True)
        
        # 更新速度(Velocity Verlet 算法的第二步)
        velocities = velocities + 0.5 * dt * new_forces / mass
        
        # 温度控制(简单的速度重缩放)
        if step % 10 == 0:
            current_temp = torch.mean(mass * torch.sum(velocities**2, dim=1)) / (3 * kB)
            scale_factor = torch.sqrt(temperature / current_temp)
            velocities = velocities * scale_factor
        
        # 保存当前位置到轨迹
        trajectory.append(positions.clone().detach())
        
        # 打印进度
        if (step + 1) % 100 == 0:
            print(f"步骤 {step+1}/{num_steps}, 能量: {energy.item():.6f} eV")
    
    # 将轨迹转换为张量
    trajectory = torch.stack(trajectory)
    
    return trajectory

这个函数实现了一个简单的分子动力学模拟,使用 Velocity Verlet 算法进行时间积分,并使用简单的速度重缩放进行温度控制。它返回分子轨迹,即原子坐标随时间的变化。

7.3.3 分析模拟结果

一旦我们有了分子轨迹,就可以计算各种物理量来分析模拟结果:

def analyze_trajectory(trajectory, atomic_numbers, dt=0.001):
    """
    分析分子动力学轨迹
    
    参数:
        trajectory: 分子轨迹,形状为 [num_steps, num_atoms, 3]
        atomic_numbers: 原子序数
        dt: 时间步长
        
    返回:
        各种分析结果
    """
    # 导入必要的库
    import torch
    import numpy as np
    import matplotlib.pyplot as plt
    
    # 轨迹信息
    num_steps, num_atoms, _ = trajectory.shape
    
    # 计算分子的质心
    # 假设原子质量相同
    center_of_mass = trajectory.mean(dim=1)  # [num_steps, 3]
    
    # 计算均方位移(MSD)
    # 相对于初始位置
    initial_positions = trajectory[0]  # [num_atoms, 3]
    displacements = trajectory - initial_positions.unsqueeze(0)  # [num_steps, num_atoms, 3]
    squared_displacements = torch.sum(displacements**2, dim=2)  # [num_steps, num_atoms]
    msd = squared_displacements.mean(dim=1)  # [num_steps]
    
    # 计算速度自相关函数(VACF)
    # 首先计算速度
    velocities = (trajectory[1:] - trajectory[:-1]) / dt  # [num_steps-1, num_atoms, 3]
    
    # 计算VACF
    initial_velocity = velocities[0]  # [num_atoms, 3]
    vacf = torch.zeros(num_steps - 1)
    for t in range(num_steps - 1):
        # 计算速度点积
        dot_products = torch.sum(velocities[t] * initial_velocity, dim=1)  # [num_atoms]
        vacf[t] = dot_products.mean()
    
    # 归一化VACF
    vacf = vacf / vacf[0]
    
    # 计算键长分布(对于第一个和最后一个原子之间的键)
    # 这只是一个示例,实际应用中应该根据分子的具体结构计算键长
    if num_atoms >= 2:
        bond_lengths = torch.norm(trajectory[:, 1, :] - trajectory[:, 0, :], dim=1)  # [num_steps]
        
        # 绘制键长分布直方图
        plt.figure(figsize=(10, 6))
        plt.hist(bond_lengths.numpy(), bins=50, alpha=0.7)
        plt.xlabel('键长 (Å)')
        plt.ylabel('频率')
        plt.title('键长分布')
        plt.grid(True, alpha=0.3)
        plt.savefig('bond_length_distribution.png')
        plt.close()
    
    # 绘制MSD
    plt.figure(figsize=(10, 6))
    time = np.arange(num_steps) * dt
    plt.plot(time, msd.numpy())
    plt.xlabel('时间 (ps)')
    plt.ylabel('MSD (Ų)')
    plt.title('均方位移')
    plt.grid(True, alpha=0.3)
    plt.savefig('mean_squared_displacement.png')
    plt.close()
    
    # 绘制VACF
    plt.figure(figsize=(10, 6))
    time = np.arange(num_steps - 1) * dt
    plt.plot(time, vacf.numpy())
    plt.xlabel('时间 (ps)')
    plt.ylabel('VACF (归一化)')
    plt.title('速度自相关函数')
    plt.grid(True, alpha=0.3)
    plt.savefig('velocity_autocorrelation.png')
    plt.close()
    
    # 返回分析结果
    results = {
        'msd': msd,
        'vacf': vacf,
        'bond_lengths': bond_lengths if num_atoms >= 2 else None
    }
    
    return results

这个函数计算了几个常见的分子动力学分析量:均方位移(MSD)、速度自相关函数(VACF)和键长分布。这些量提供了关于分子动力学和结构的重要信息。

7.4 3D 点云处理

3D 点云处理是等变神经网络的另一个重要应用领域。在本节中,我们将展示如何使用 cuEquivariance 构建用于点云分类和分割的等变神经网络。

7.4.1 点云分类模型

点云分类是识别点云表示的 3D 对象类别的任务。由于点云的性质不应依赖于它们在空间中的方向,等变神经网络是这类问题的理想选择。

以下是一个使用 cuEquivariance 实现的点云分类模型:

import torch
import torch.nn as nn
import cuequivariance as cue
import cuequivariance_torch as cuet

class EquivariantPointCloudModel(nn.Module):
    """
    用于点云分类的等变神经网络模型
    """
    def __init__(self, num_classes, max_l=2, hidden_channels=64):
        """
        初始化模型
        
        参数:
            num_classes: 类别数量
            max_l: 最大角动量量子数
            hidden_channels: 隐藏层通道数
        """
        super().__init__()
        
        # 定义表示
        # 初始点特征:标量特征(1表示每个点的存在)
        self.irreps_point = cue.Irreps("SO3", "1x0")
        
        # 位置编码:使用球谐函数,包括 l=0,1,...,max_l
        l_values = list(range(max_l + 1))
        sh_irreps_str = " + ".join([f"1x{l}" for l in l_values])
        self.irreps_sh = cue.Irreps("SO3", sh_irreps_str)
        
        # 隐藏层表示:包括标量和向量特征
        self.irreps_hidden1 = cue.Irreps("SO3", f"{hidden_channels}x0 + {hidden_channels//2}x1")
        self.irreps_hidden2 = cue.Irreps("SO3", f"{hidden_channels*2}x0 + {hidden_channels}x1 + {hidden_channels//2}x2")
        
        # 全局特征表示:只有标量(旋转不变)
        self.irreps_global = cue.Irreps("SO3", f"{hidden_channels*4}x0")
        
        # 创建球谐函数描述符
        self.sh_desc = cue.descriptors.spherical_harmonics(cue.SO3(1), l_values)
        self.sh = cuet.EquivariantTensorProduct(
            self.sh_desc, 
            layout=cue.ir_mul,
            use_fallback=True
        )
        self.sh_weights = nn.Parameter(torch.ones(1, self.sh_desc.weight_numel))
        
        # 创建第一层等变卷积
        self.conv1_desc = cue.descriptors.tensor_product(
            self.irreps_point, self.irreps_sh, self.irreps_hidden1
        )
        self.conv1 = cuet.EquivariantTensorProduct(
            self.conv1_desc, 
            layout=cue.ir_mul,
            use_fallback=True
        )
        self.conv1_weights = nn.Parameter(torch.randn(1, self.conv1_desc.weight_numel))
        
        # 创建第二层等变卷积
        self.conv2_desc = cue.descriptors.tensor_product(
            self.irreps_hidden1, self.irreps_sh, self.irreps_hidden2
        )
        self.conv2 = cuet.EquivariantTensorProduct(
            self.conv2_desc, 
            layout=cue.ir_mul,
            use_fallback=True
        )
        self.conv2_weights = nn.Parameter(torch.randn(1, self.conv2_desc.weight_numel))
        
        # 创建全局池化层(将点特征映射到全局特征)
        self.global_desc = cue.descriptors.linear(
            self.irreps_hidden2, self.irreps_global
        )
        self.global_pool = cuet.EquivariantTensorProduct(
            self.global_desc, 
            layout=cue.ir_mul,
            use_fallback=True
        )
        self.global_weights = nn.Parameter(torch.randn(1, self.global_desc.weight_numel))
        
        # 创建分类头(标准MLP)
        self.classifier = nn.Sequential(
            nn.Linear(self.irreps_global.dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )
    
    def forward(self, points):
        """
        前向传播
        
        参数:
            points: 点云坐标,形状为 [batch_size, num_points, 3]
            
        返回:
            类别预测,形状为 [batch_size, num_classes]
        """
        batch_size, num_points, _ = points.shape
        
        # 创建初始点特征(每个点的存在标记为1)
        point_features = torch.ones(batch_size, num_points, self.irreps_point.dim, device=points.device)
        
        # 为每个点计算局部邻居
        # 这里我们使用一个简化的方法,为每个点找到k个最近邻
        k = min(20, num_points - 1)  # 邻居数量
        
        # 计算点之间的距离
        points_flat = points.view(batch_size * num_points, 1, 3)
        points_flat_t = points.view(batch_size, 1, num_points, 3).repeat(1, num_points, 1, 1)
        points_flat_t = points_flat_t.view(batch_size * num_points, num_points, 3)
        
        dist = torch.sum((points_flat - points_flat_t) ** 2, dim=2)  # [batch_size*num_points, num_points]
        
        # 找到k个最近邻(不包括自身)
        dist[:, 0] = float('inf')  # 排除自身
        _, idx = torch.topk(dist, k=k, dim=1, largest=False)  # [batch_size*num_points, k]
        
        # 计算相对位置
        idx_base = torch.arange(0, batch_size, device=points.device).view(-1, 1, 1) * num_points
        idx_base = idx_base.repeat(1, num_points, k)
        idx_base = idx_base.view(-1, k)
        
        idx = idx + idx_base
        idx = idx.view(-1)
        
        points_flat = points.view(batch_size * num_points, 3)
        neighbors = points_flat[idx].view(batch_size * num_points, k, 3)
        
        # 计算相对位置
        rel_pos = neighbors - points_flat.unsqueeze(1)  # [batch_size*num_points, k, 3]
        
        # 计算距离
        distances = torch.norm(rel_pos, dim=2, keepdim=True)  # [batch_size*num_points, k, 1]
        
        # 归一化相对位置为单位向量
        unit_vectors = rel_pos / distances  # [batch_size*num_points, k, 3]
        
        # 计算球谐函数值
        unit_vectors_flat = unit_vectors.view(-1, 3)
        sh_values = self.sh(self.sh_weights, unit_vectors_flat)  # [batch_size*num_points*k, irreps_sh.dim]
        sh_values = sh_values.view(batch_size * num_points, k, -1)
        
        # 使用径向函数调制球谐函数值
        radial_function = torch.exp(-(distances - 0.5)**2 / 0.1)  # [batch_size*num_points, k, 1]
        sh_values = sh_values * radial_function  # [batch_size*num_points, k, irreps_sh.dim]
        
        # 获取邻居的特征
        point_features_flat = point_features.view(batch_size * num_points, -1)
        neighbor_features = point_features_flat[idx].view(batch_size * num_points, k, -1)
        
        # 第一层等变卷积
        # 对每个邻居应用等变张量积
        neighbor_features_flat = neighbor_features.view(-1, self.irreps_point.dim)
        sh_values_flat = sh_values.view(-1, self.irreps_sh.dim)
        
        conv1_out = self.conv1(
            self.conv1_weights.repeat(neighbor_features_flat.shape[0], 1), 
            neighbor_features_flat, 
            sh_values_flat
        )  # [batch_size*num_points*k, irreps_hidden1.dim]
        
        # 聚合邻居特征(求和池化)
        conv1_out = conv1_out.view(batch_size * num_points, k, -1)
        conv1_out = conv1_out.sum(dim=1)  # [batch_size*num_points, irreps_hidden1.dim]
        
        # 应用非线性(只对标量部分)
        scalar_dim1 = sum(mul * ir.dim for mul, ir in self.irreps_hidden1 if ir.l == 0)
        if scalar_dim1 > 0:
            conv1_scalar = conv1_out[:, :scalar_dim1]
            conv1_vector = conv1_out[:, scalar_dim1:]
            conv1_scalar = torch.relu(conv1_scalar)
            conv1_out = torch.cat([conv1_scalar, conv1_vector], dim=1)
        
        # 第二层等变卷积
        # 重新计算邻居特征
        conv1_out_expanded = conv1_out.view(batch_size, num_points, -1)
        conv1_out_flat = conv1_out_expanded.view(batch_size * num_points, -1)
        neighbor_features2 = conv1_out_flat[idx].view(batch_size * num_points, k, -1)
        
        # 对每个邻居应用等变张量积
        neighbor_features2_flat = neighbor_features2.view(-1, self.irreps_hidden1.dim)
        
        conv2_out = self.conv2(
            self.conv2_weights.repeat(neighbor_features2_flat.shape[0], 1), 
            neighbor_features2_flat, 
            sh_values_flat
        )  # [batch_size*num_points*k, irreps_hidden2.dim]
        
        # 聚合邻居特征(求和池化)
        conv2_out = conv2_out.view(batch_size * num_points, k, -1)
        conv2_out = conv2_out.sum(dim=1)  # [batch_size*num_points, irreps_hidden2.dim]
        
        # 应用非线性(只对标量部分)
        scalar_dim2 = sum(mul * ir.dim for mul, ir in self.irreps_hidden2 if ir.l == 0)
        if scalar_dim2 > 0:
            conv2_scalar = conv2_out[:, :scalar_dim2]
            conv2_vector = conv2_out[:, scalar_dim2:]
            conv2_scalar = torch.relu(conv2_scalar)
            conv2_out = torch.cat([conv2_scalar, conv2_vector], dim=1)
        
        # 将点特征转换为全局特征(只保留标量部分,确保旋转不变性)
        conv2_out = conv2_out.view(batch_size, num_points, -1)
        
        # 对每个点应用全局池化
        global_features_per_point = self.global_pool(
            self.global_weights.repeat(batch_size * num_points, 1),
            conv2_out.view(batch_size * num_points, -1)
        )  # [batch_size*num_points, irreps_global.dim]
        
        global_features_per_point = global_features_per_point.view(batch_size, num_points, -1)
        
        # 聚合所有点的特征(最大池化)
        global_features, _ = torch.max(global_features_per_point, dim=1)  # [batch_size, irreps_global.dim]
        
        # 应用分类头
        logits = self.classifier(global_features)  # [batch_size, num_classes]
        
        return logits

这个模型实现了一个用于点云分类的等变神经网络。它包括以下关键组件:

  1. 点特征初始化:每个点初始化为一个简单的标量特征
  2. 局部邻居查找:为每个点找到 k 个最近邻
  3. 球谐函数编码:使用球谐函数将相对位置编码为等变特征
  4. 等变卷积:使用等变张量积实现点云卷积
  5. 全局池化:将点特征聚合为全局特征,只保留标量部分以确保旋转不变性
  6. 分类头:使用标准 MLP 进行最终分类

7.4.2 点云分割模型

点云分割是为点云中的每个点分配语义标签的任务。与点云分类类似,等变神经网络也适用于这个任务。

以下是对上述模型的修改,使其适用于点云分割:

class EquivariantPointCloudSegmentation(EquivariantPointCloudModel):
    """
    用于点云分割的等变神经网络模型
    """
    def __init__(self, num_classes, max_l=2, hidden_channels=64):
        """
        初始化模型
        
        参数:
            num_classes: 类别数量
            max_l: 最大角动量量子数
            hidden_channels: 隐藏层通道数
        """
        # 调用父类初始化
        super().__init__(num_classes, max_l, hidden_channels)
        
        # 替换分类头为分割头
        # 分割头需要为每个点预测类别
        self.segmentation_desc = cue.descriptors.linear(
            self.irreps_hidden2, cue.Irreps("SO3", f"{hidden_channels*2}x0")
        )
        self.segmentation = cuet.EquivariantTensorProduct(
            self.segmentation_desc, 
            layout=cue.ir_mul,
            use_fallback=True
        )
        self.segmentation_weights = nn.Parameter(torch.randn(1, self.segmentation_desc.weight_numel))
        
        # 点级分类器
        self.point_classifier = nn.Sequential(
            nn.Linear(hidden_channels*2, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )
    
    def forward(self, points):
        """
        前向传播
        
        参数:
            points: 点云坐标,形状为 [batch_size, num_points, 3]
            
        返回:
            点级别的类别预测,形状为 [batch_size, num_points, num_classes]
        """
        batch_size, num_points, _ = points.shape
        
        # 重用父类的大部分实现,但在conv2_out之后分叉
        
        # 创建初始点特征(每个点的存在标记为1)
        point_features = torch.ones(batch_size, num_points, self.irreps_point.dim, device=points.device)
        
        # 为每个点计算局部邻居
        # 这里我们使用一个简化的方法,为每个点找到k个最近邻
        k = min(20, num_points - 1)  # 邻居数量
        
        # 计算点之间的距离
        points_flat = points.view(batch_size * num_points, 1, 3)
        points_flat_t = points.view(batch_size, 1, num_points, 3).repeat(1, num_points, 1, 1)
        points_flat_t = points_flat_t.view(batch_size * num_points, num_points, 3)
        
        dist = torch.sum((points_flat - points_flat_t) ** 2, dim=2)  # [batch_size*num_points, num_points]
        
        # 找到k个最近邻(不包括自身)
        dist[:, 0] = float('inf')  # 排除自身
        _, idx = torch.topk(dist, k=k, dim=1, largest=False)  # [batch_size*num_points, k]
        
        # 计算相对位置
        idx_base = torch.arange(0, batch_size, device=points.device).view(-1, 1, 1) * num_points
        idx_base = idx_base.repeat(1, num_points, k)
        idx_base = idx_base.view(-1, k)
        
        idx = idx + idx_base
        idx = idx.view(-1)
        
        points_flat = points.view(batch_size * num_points, 3)
        neighbors = points_flat[idx].view(batch_size * num_points, k, 3)
        
        # 计算相对位置
        rel_pos = neighbors - points_flat.unsqueeze(1)  # [batch_size*num_points, k, 3]
        
        # 计算距离
        distances = torch.norm(rel_pos, dim=2, keepdim=True)  # [batch_size*num_points, k, 1]
        
        # 归一化相对位置为单位向量
        unit_vectors = rel_pos / distances  # [batch_size*num_points, k, 3]
        
        # 计算球谐函数值
        unit_vectors_flat = unit_vectors.view(-1, 3)
        sh_values = self.sh(self.sh_weights, unit_vectors_flat)  # [batch_size*num_points*k, irreps_sh.dim]
        sh_values = sh_values.view(batch_size * num_points, k, -1)
        
        # 使用径向函数调制球谐函数值
        radial_function = torch.exp(-(distances - 0.5)**2 / 0.1)  # [batch_size*num_points, k, 1]
        sh_values = sh_values * radial_function  # [batch_size*num_points, k, irreps_sh.dim]
        
        # 获取邻居的特征
        point_features_flat = point_features.view(batch_size * num_points, -1)
        neighbor_features = point_features_flat[idx].view(batch_size * num_points, k, -1)
        
        # 第一层等变卷积
        # 对每个邻居应用等变张量积
        neighbor_features_flat = neighbor_features.view(-1, self.irreps_point.dim)
        sh_values_flat = sh_values.view(-1, self.irreps_sh.dim)
        
        conv1_out = self.conv1(
            self.conv1_weights.repeat(neighbor_features_flat.shape[0], 1), 
            neighbor_features_flat, 
            sh_values_flat
        )  # [batch_size*num_points*k, irreps_hidden1.dim]
        
        # 聚合邻居特征(求和池化)
        conv1_out = conv1_out.view(batch_size * num_points, k, -1)
        conv1_out = conv1_out.sum(dim=1)  # [batch_size*num_points, irreps_hidden1.dim]
        
        # 应用非线性(只对标量部分)
        scalar_dim1 = sum(mul * ir.dim for mul, ir in self.irreps_hidden1 if ir.l == 0)
        if scalar_dim1 > 0:
            conv1_scalar = conv1_out[:, :scalar_dim1]
            conv1_vector = conv1_out[:, scalar_dim1:]
            conv1_scalar = torch.relu(conv1_scalar)
            conv1_out = torch.cat([conv1_scalar, conv1_vector], dim=1)
        
        # 第二层等变卷积
        # 重新计算邻居特征
        conv1_out_expanded = conv1_out.view(batch_size, num_points, -1)
        conv1_out_flat = conv1_out_expanded.view(batch_size * num_points, -1)
        neighbor_features2 = conv1_out_flat[idx].view(batch_size * num_points, k, -1)
        
        # 对每个邻居应用等变张量积
        neighbor_features2_flat = neighbor_features2.view(-1, self.irreps_hidden1.dim)
        
        conv2_out = self.conv2(
            self.conv2_weights.repeat(neighbor_features2_flat.shape[0], 1), 
            neighbor_features2_flat, 
            sh_values_flat
        )  # [batch_size*num_points*k, irreps_hidden2.dim]
        
        # 聚合邻居特征(求和池化)
        conv2_out = conv2_out.view(batch_size * num_points, k, -1)
        conv2_out = conv2_out.sum(dim=1)  # [batch_size*num_points, irreps_hidden2.dim]
        
        # 应用非线性(只对标量部分)
        scalar_dim2 = sum(mul * ir.dim for mul, ir in self.irreps_hidden2 if ir.l == 0)
        if scalar_dim2 > 0:
            conv2_scalar = conv2_out[:, :scalar_dim2]
            conv2_vector = conv2_out[:, scalar_dim2:]
            conv2_scalar = torch.relu(conv2_scalar)
            conv2_out = torch.cat([conv2_scalar, conv2_vector], dim=1)
        
        # 从这里开始,分割模型与分类模型不同
        # 为每个点预测类别,而不是为整个点云预测一个类别
        
        # 将点特征转换为分割特征(只保留标量部分,确保旋转不变性)
        segmentation_features = self.segmentation(
            self.segmentation_weights.repeat(batch_size * num_points, 1),
            conv2_out
        )  # [batch_size*num_points, hidden_channels*2]
        
        # 应用点级分类器
        point_logits = self.point_classifier(segmentation_features)  # [batch_size*num_points, num_classes]
        
        # 重塑为 [batch_size, num_points, num_classes]
        point_logits = point_logits.view(batch_size, num_points, -1)
        
        return point_logits

这个模型基于前面的点云分类模型,但做了一个关键修改:它为每个点预测一个类别,而不是为整个点云预测一个类别。这是通过添加一个点级分类头实现的,该分类头将每个点的特征映射到类别预测。

7.4.3 评估旋转等变性

为了验证我们的点云模型的旋转等变性,我们可以进行以下实验:

def verify_point_cloud_equivariance(model, point_cloud):
    """
    验证点云模型的旋转等变性
    
    参数:
        model: 等变点云模型
        point_cloud: 点云数据,形状为 [1, num_points, 3]
        
    返回:
        原始预测和旋转后预测之间的相对误差
    """
    # 导入必要的库
    import torch
    from scipy.spatial.transform import Rotation
    
    # 设置模型为评估模式
    model.eval()
    
    # 禁用梯度计算
    with torch.no_grad():
        # 计算原始预测
        original_prediction = model(point_cloud)
        
        # 创建一个随机旋转
        rotation = Rotation.random()
        rotation_matrix = torch.tensor(rotation.as_matrix(), dtype=torch.float32)
        
        # 旋转点云
        rotated_point_cloud = torch.matmul(point_cloud, rotation_matrix.T)
        
        # 计算旋转后的预测
        rotated_prediction = model(rotated_point_cloud)
        
        # 计算相对误差
        if isinstance(model, EquivariantPointCloudModel):
            # 分类模型:预测是类别概率
            error = torch.norm(original_prediction - rotated_prediction) / torch.norm(original_prediction)
            print(f"分类预测相对误差: {error.item():.6f}")
        else:
            # 分割模型:预测是每个点的类别概率
            error = torch.norm(original_prediction - rotated_prediction) / torch.norm(original_prediction)
            print(f"分割预测相对误差: {error.item():.6f}")
        
        return error.item()

如果模型具有良好的旋转等变性,那么原始预测和旋转后预测之间的误差应该很小。这验证了等变神经网络在处理 3D 点云数据时的优势。

NVIDIA cuEquivariance 详细教程:Beta 特性与实验性功能

8.1 JIT 内核

cuEquivariance 提供了实验性的即时编译(JIT)内核功能,可以根据特定的操作参数动态生成优化的 CUDA 代码。这些 JIT 内核可以在某些情况下提供更好的性能,特别是对于非标准或不常见的表示组合。

8.1.1 启用 JIT 内核

要启用 JIT 内核,可以通过环境变量或在代码中直接设置:

# 方法1:通过环境变量启用
import os
os.environ["CUEQUIVARIANCE_OPS_USE_JIT"] = "1"

# 方法2:在代码中直接设置
import cuequivariance_torch as cuet
cuet.set_use_jit(True)

启用 JIT 内核后,cuEquivariance 将尝试为每个操作生成优化的 CUDA 代码。如果 JIT 编译失败或不支持特定操作,它将自动回退到预编译的内核。

8.1.2 JIT 内核的优势

JIT 内核提供了几个潜在的优势:

  1. 针对特定操作的优化:JIT 内核可以根据特定的表示和操作参数生成高度优化的代码,而不是使用通用的预编译内核。

  2. 支持非标准表示:对于预编译内核可能不支持的不常见表示组合,JIT 内核可以动态生成适当的代码。

  3. 减少内存使用:在某些情况下,JIT 内核可以通过更有效的内存访问模式减少内存使用。

8.1.3 JIT 内核的限制

尽管 JIT 内核提供了潜在的优势,但它们也有一些限制:

  1. 编译开销:第一次使用特定配置时,JIT 编译会引入额外的延迟。

  2. 实验性功能:作为 Beta 特性,JIT 内核可能不如预编译内核稳定。

  3. 系统要求:JIT 编译需要系统上安装了适当的 CUDA 工具链。

8.1.4 JIT 内核性能分析

以下是一个简单的性能分析示例,比较使用和不使用 JIT 内核的性能差异:

import torch
import cuequivariance as cue
import cuequivariance_torch as cuet
import time

def benchmark_jit_kernels():
    """
    比较使用和不使用JIT内核的性能
    """
    # 创建一个复杂的等变张量积描述符
    irreps_in1 = cue.Irreps("SO3", "10x0 + 5x1 + 3x2")
    irreps_in2 = cue.Irreps("SO3", "8x0 + 4x1 + 2x2")
    irreps_out = cue.Irreps("SO3", "15x0 + 10x1 + 5x2")
    
    tensor_product = cue.descriptors.tensor_product(
        irreps_in1, irreps_in2, irreps_out
    )
    
    # 创建输入数据
    batch_size = 32
    x1 = torch.randn(batch_size, irreps_in1.dim, device="cuda")
    x2 = torch.randn(batch_size, irreps_in2.dim, device="cuda")
    weights = torch.randn(1, tensor_product.weight_numel, device="cuda")
    
    # 不使用JIT内核
    cuet.set_use_jit(False)
    module_no_jit = cuet.EquivariantTensorProduct(
        tensor_product, 
        layout=cue.ir_mul,
        use_fallback=True
    ).cuda()
    
    # 预热
    for _ in range(10):
        _ = module_no_jit(weights, x1, x2)
    
    # 计时
    torch.cuda.synchronize()
    start_time = time.time()
    for _ in range(100):
        _ = module_no_jit(weights, x1, x2)
    torch.cuda.synchronize()
    no_jit_time = time.time() - start_time
    
    # 使用JIT内核
    cuet.set_use_jit(True)
    module_jit = cuet.EquivariantTensorProduct(
        tensor_product, 
        layout=cue.ir_mul,
        use_fallback=True
    ).cuda()
    
    # 预热(包括JIT编译时间)
    for _ in range(10):
        _ = module_jit(weights, x1, x2)
    
    # 计时(不包括JIT编译时间)
    torch.cuda.synchronize()
    start_time = time.time()
    for _ in range(100):
        _ = module_jit(weights, x1, x2)
    torch.cuda.synchronize()
    jit_time = time.time() - start_time
    
    print(f"不使用JIT内核: {no_jit_time:.6f} 秒")
    print(f"使用JIT内核: {jit_time:.6f} 秒")
    print(f"加速比: {no_jit_time/jit_time:.2f}x")
    
    return no_jit_time, jit_time

在某些情况下,JIT 内核可以提供显著的性能提升,特别是对于复杂的表示组合。然而,性能提升的程度取决于具体的操作和硬件。

8.2 融合的散射/收集内核

cuEquivariance 还提供了实验性的融合散射/收集内核,这些内核可以优化等变操作中常见的数据重组模式。

8.2.1 散射和收集操作

在等变神经网络中,散射(scatter)和收集(gather)是两种常见的数据重组操作:

  • 散射:将数据从源索引分散到目标索引
  • 收集:从源索引收集数据到目标索引

这些操作在处理图数据、点云或分子结构时特别常见,例如在消息传递神经网络中。

8.2.2 融合内核的优势

融合的散射/收集内核将多个操作合并为单个优化的内核,提供以下优势:

  1. 减少内存访问:通过减少中间结果的存储和加载,降低内存带宽需求。

  2. 提高计算效率:通过更好地利用 GPU 资源,提高计算效率。

  3. 减少内核启动开销:减少 CUDA 内核启动的次数,降低相关开销。

8.2.3 使用融合内核

以下是一个使用融合散射/收集内核的示例:

import torch
import cuequivariance as cue
import cuequivariance_torch as cuet

def use_fused_scatter_gather():
    """
    演示融合散射/收集内核的使用
    """
    # 创建一个图结构
    num_nodes = 1000
    num_edges = 5000
    
    # 随机生成边索引
    edge_index = torch.randint(0, num_nodes, (2, num_edges), device="cuda")
    src, dst = edge_index
    
    # 创建节点特征
    irreps_node = cue.Irreps("SO3", "16x0 + 8x1")
    node_features = torch.randn(num_nodes, irreps_node.dim, device="cuda")
    
    # 创建边特征
    irreps_edge = cue.Irreps("SO3", "4x0 + 2x1")
    edge_features = torch.randn(num_edges, irreps_edge.dim, device="cuda")
    
    # 创建消息传递描述符
    irreps_msg = cue.Irreps("SO3", "8x0 + 4x1")
    message_passing = cue.descriptors.tensor_product(
        irreps_node, irreps_edge, irreps_msg
    )
    
    # 创建消息传递模块(使用融合内核)
    message_module = cuet.EquivariantTensorProduct(
        message_passing, 
        layout=cue.ir_mul,
        use_fallback=True,
        use_fused_kernels=True  # 启用融合内核
    ).cuda()
    
    # 创建权重
    weights = torch.randn(1, message_passing.weight_numel, device="cuda")
    
    # 执行消息传递
    # 1. 收集源节点特征
    src_features = node_features[src]
    
    # 2. 应用等变张量积
    messages = message_module(weights, src_features, edge_features)
    
    # 3. 散射消息到目标节点
    result = torch.zeros(num_nodes, irreps_msg.dim, device="cuda")
    result.index_add_(0, dst, messages)
    
    return result

在这个例子中,我们通过设置 use_fused_kernels=True 启用了融合内核。这将尝试使用优化的内核来执行消息传递操作,包括收集源节点特征、应用等变张量积和散射消息到目标节点。

8.2.4 融合内核的限制

融合内核也有一些限制:

  1. 实验性功能:作为 Beta 特性,融合内核可能不如标准内核稳定。

  2. 有限的操作支持:并非所有操作组合都支持融合。

  3. 内存使用权衡:虽然融合内核可以减少内存访问,但在某些情况下可能需要更多的临时存储空间。

8.3 自定义等变操作

cuEquivariance 提供了灵活的 API 来定义自定义等变操作,允许用户根据特定需求创建专门的等变层。

8.3.1 创建自定义分段张量积

分段张量积(Segmented Tensor Product)是 cuEquivariance 中等变操作的基本构建块。以下是创建自定义分段张量积的示例:

import cuequivariance as cue
import torch

def create_custom_segmented_tensor_product():
    """
    创建自定义分段张量积
    """
    # 创建一个从下标表达式构建的分段张量积
    # 这个表达式表示 C_iu = A_i * B_u
    stp = cue.SegmentedTensorProduct.from_subscripts("i,u,iu")
    
    # 添加段
    # 段定义了如何将输入张量分割成块
    stp.add_segment(0, 0, 10)  # 输入0(A)的第0段包含10个元素
    stp.add_segment(1, 0, 5)   # 输入1(B)的第0段包含5个元素
    stp.add_segment(2, 0, 50)  # 输出(C)的第0段包含50个元素
    
    # 添加路径
    # 路径定义了如何将输入段映射到输出段
    stp.add_path(0, 0, 1, 0, 2, 0)  # 将输入0的段0和输入1的段0映射到输出的段0
    
    # 设置权重维度
    stp.set_weight_dim(1)
    
    return stp

def use_custom_segmented_tensor_product():
    """
    使用自定义分段张量积
    """
    # 创建自定义分段张量积
    stp = create_custom_segmented_tensor_product()
    
    # 创建输入张量
    batch_size = 2
    A = torch.randn(batch_size, 10)  # 输入0,10个元素
    B = torch.randn(batch_size, 5)   # 输入1,5个元素
    
    # 创建权重
    weights = torch.randn(1, stp.weight_numel)
    
    # 在PyTorch中使用自定义分段张量积
    import cuequivariance_torch as cuet
    
    module = cuet.SegmentedTensorProductFunction.apply
    C = module(stp, weights, A, B)
    
    print(f"输入0形状: {A.shape}")
    print(f"输入1形状: {B.shape}")
    print(f"输出形状: {C.shape}")  # 应该是 [batch_size, 50]
    
    return C

这个例子展示了如何创建一个简单的自定义分段张量积,它将两个输入张量相乘并映射到输出张量。

8.3.2 创建自定义等变层

基于自定义分段张量积,我们可以创建自定义等变层:

import torch
import torch.nn as nn
import cuequivariance as cue
import cuequivariance_torch as cuet

class CustomEquivariantLayer(nn.Module):
    """
    自定义等变层
    """
    def __init__(self, irreps_in1, irreps_in2, irreps_out):
        """
        初始化自定义等变层
        
        参数:
            irreps_in1: 第一个输入的不可约表示
            irreps_in2: 第二个输入的不可约表示
            irreps_out: 输出的不可约表示
        """
        super().__init__()
        
        # 创建等变张量积
        self.etp = cue.EquivariantTensorProduct(
            irreps_in=[irreps_in1, irreps_in2],
            irreps_out=irreps_out
        )
        
        # 创建自定义分段张量积
        stp = create_custom_segmented_tensor_product()
        
        # 将分段张量积添加到等变张量积
        self.etp.ds.append(stp)
        
        # 创建PyTorch模块
        self.module = cuet.EquivariantTensorProduct(
            self.etp, 
            layout=cue.ir_mul,
            use_fallback=True
        )
        
        # 创建权重参数
        self.weights = nn.Parameter(torch.randn(1, self.etp.weight_numel))
    
    def forward(self, x1, x2):
        """
        前向传播
        
        参数:
            x1: 第一个输入张量
            x2: 第二个输入张量
            
        返回:
            输出张量
        """
        return self.module(self.weights, x1, x2)

这个自定义等变层使用我们前面定义的自定义分段张量积,并将其包装在 PyTorch 模块中,使其易于在深度学习模型中使用。

8.3.3 自定义等变非线性

等变神经网络中的一个挑战是设计保持等变性的非线性激活函数。以下是一个自定义等变非线性的示例:

import torch
import torch.nn as nn
import cuequivariance as cue
import cuequivariance_torch as cuet

class EquivariantNonlinearity(nn.Module):
    """
    自定义等变非线性激活函数
    """
    def __init__(self, irreps):
        """
        初始化等变非线性
        
        参数:
            irreps: 输入和输出的不可约表示
        """
        super().__init__()
        
        # 分离标量和非标量部分
        self.scalar_indices = []
        self.vector_indices = []
        
        start_idx = 0
        for i, (mul, ir) in enumerate(irreps):
            dim = mul * ir.dim
            if ir.l == 0:  # 标量
                self.scalar_indices.extend(range(start_idx, start_idx + dim))
            else:  # 非标量
                self.vector_indices.extend(range(start_idx, start_idx + dim))
            start_idx += dim
        
        # 创建门控机制的权重
        if len(self.scalar_indices) > 0 and len(self.vector_indices) > 0:
            self.gate_weights = nn.Parameter(torch.ones(len(self.scalar_indices), len(self.vector_indices) // 3))
        else:
            self.gate_weights = None
    
    def forward(self, x):
        """
        前向传播
        
        参数:
            x: 输入张量
            
        返回:
            应用非线性后的张量
        """
        # 创建输出张量
        out = torch.zeros_like(x)
        
        # 应用非线性到标量部分
        if len(self.scalar_indices) > 0:
            scalar_features = x[:, self.scalar_indices]
            out[:, self.scalar_indices] = torch.sigmoid(scalar_features)
        
        # 应用门控非线性到非标量部分
        if len(self.vector_indices) > 0 and len(self.scalar_indices) > 0 and self.gate_weights is not None:
            vector_features = x[:, self.vector_indices]
            
            # 重塑向量特征为 [batch_size, num_vectors, 3]
            num_vectors = len(self.vector_indices) // 3
            vector_features = vector_features.view(x.shape[0], num_vectors, 3)
            
            # 使用标量特征和门控权重计算门控因子
            gates = torch.sigmoid(torch.matmul(scalar_features, self.gate_weights))  # [batch_size, num_vectors]
            
            # 应用门控
            gated_vectors = vector_features * gates.unsqueeze(-1)  # [batch_size, num_vectors, 3]
            
            # 重塑回原始形状
            out[:, self.vector_indices] = gated_vectors.reshape(x.shape[0], -1)
        
        return out

这个自定义等变非线性使用标量特征来门控非标量特征,保持等变性。对于标量部分,我们可以应用任何标准非线性(如 sigmoid),因为标量在旋转下不变。对于非标量部分,我们使用标量特征和可学习的权重来计算门控因子,然后将这些因子应用到非标量特征上。

8.4 性能优化技巧

在使用 cuEquivariance 构建等变神经网络时,有几种技巧可以优化性能。

8.4.1 选择合适的表示

表示的选择对性能有显著影响。一般来说:

  1. 限制高阶表示:高阶表示(如 l=3, l=4 等)的计算成本显著高于低阶表示。除非必要,否则应限制使用高阶表示。

  2. 平衡标量和非标量特征:标量特征(l=0)的计算效率最高,但表达能力有限。向量特征(l=1)提供了良好的表达能力和计算效率的平衡。

  3. 考虑表示的多重性:增加多重性(例如,使用 “10x0” 而不是 “1x0”)通常比增加表示的阶数(例如,添加 l=3, l=4 等)更有效。

8.4.2 批处理和并行化

有效的批处理和并行化可以显著提高性能:

  1. 使用适当的批量大小:太小的批量大小可能无法充分利用 GPU,而太大的批量大小可能导致内存不足。

  2. 利用数据并行:对于大型模型,考虑使用 PyTorch 的 DataParallelDistributedDataParallel 在多个 GPU 上并行训练。

  3. 预取数据:使用 PyTorch 的 DataLoadernum_workerspin_memory 参数优化数据加载。

8.4.3 内存优化

等变神经网络可能比传统网络需要更多的内存。以下是一些内存优化技巧:

  1. 使用混合精度训练:在 PyTorch 中使用 torch.cuda.amp 模块进行混合精度训练,可以减少内存使用并提高性能。

  2. 梯度累积:如果批量大小受内存限制,可以使用梯度累积来模拟更大的批量大小。

  3. 检查点技术:对于非常深的网络,考虑使用检查点技术(如 torch.utils.checkpoint)来减少内存使用,以牺牲一些计算时间为代价。

8.4.4 编译优化

编译优化可以提高运行时性能:

  1. 使用 JIT 内核:如前所述,启用 JIT 内核可以在某些情况下提供性能提升。

  2. TorchScript:考虑使用 TorchScript 编译模型,这可以在某些情况下提高性能。

  3. CUDA 图:对于固定大小的输入,考虑使用 CUDA 图来减少 CUDA 内核启动的开销。

8.4.5 性能分析和调试

定期分析模型性能可以帮助识别瓶颈:

  1. 使用 PyTorch Profiler:PyTorch 提供了内置的分析器,可以帮助识别计算和内存瓶颈。

  2. 监控 GPU 使用:使用工具如 nvidia-smi 或 PyTorch 的 torch.cuda.memory_allocated() 监控 GPU 内存使用。

  3. 分析每层性能:分析每个等变层的性能,识别最耗时的操作。

以下是一个使用 PyTorch Profiler 分析等变网络性能的示例:

import torch
from torch.profiler import profile, record_function, ProfilerActivity
import cuequivariance as cue
import cuequivariance_torch as cuet

def profile_equivariant_network():
    """
    使用PyTorch Profiler分析等变网络性能
    """
    # 创建一个简单的等变网络
    irreps_in = cue.Irreps("SO3", "10x0 + 5x1")
    irreps_hidden = cue.Irreps("SO3", "20x0 + 10x1")
    irreps_out = cue.Irreps("SO3", "5x0 + 3x1")
    
    # 创建等变线性层
    linear1_desc = cue.descriptors.linear(irreps_in, irreps_hidden)
    linear1 = cuet.EquivariantTensorProduct(
        linear1_desc, 
        layout=cue.ir_mul,
        use_fallback=True
    ).cuda()
    
    linear2_desc = cue.descriptors.linear(irreps_hidden, irreps_out)
    linear2 = cuet.EquivariantTensorProduct(
        linear2_desc, 
        layout=cue.ir_mul,
        use_fallback=True
    ).cuda()
    
    # 创建权重
    weights1 = torch.randn(1, linear1_desc.weight_numel, device="cuda")
    weights2 = torch.randn(1, linear2_desc.weight_numel, device="cuda")
    
    # 创建输入数据
    batch_size = 32
    x = torch.randn(batch_size, irreps_in.dim, device="cuda")
    
    # 使用PyTorch Profiler分析性能
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True
    ) as prof:
        # 第一层
        with record_function("linear1"):
            h = linear1(weights1, x)
        
        # 非线性(只应用于标量部分)
        with record_function("nonlinearity"):
            scalar_dim = sum(mul * ir.dim for mul, ir in linear1_desc.irreps_out if ir.l == 0)
            if scalar_dim > 0:
                h_scalar = h[:, :scalar_dim]
                h_vector = h[:, scalar_dim:]
                h_scalar = torch.relu(h_scalar)
                h = torch.cat([h_scalar, h_vector], dim=1)
        
        # 第二层
        with record_function("linear2"):
            y = linear2(weights2, h)
    
    # 打印分析结果
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
    
    # 导出Chrome跟踪文件以进行可视化
    prof.export_chrome_trace("equivariant_network_trace.json")
    
    return prof

通过这些性能优化技巧,可以显著提高等变神经网络的训练和推理效率,使其在实际应用中更加实用。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

扫地的小何尚

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值