Lucidrains Series Source Code Walkthrough (Part 89)

.\lucidrains\se3-transformer-pytorch\se3_transformer_pytorch\__init__.py

# 从 se3_transformer_pytorch 库中导入 SE3Transformer 类
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer

.\lucidrains\se3-transformer-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'se3-transformer-pytorch',  # 包的名称
  packages = find_packages(),  # 查找所有包
  include_package_data = True,  # 包含所有数据文件
  version = '0.9.0',  # 版本号
  license='MIT',  # 许可证
  description = 'SE3 Transformer - Pytorch',  # 描述
  author = 'Phil Wang',  # 作者
  author_email = 'lucidrains@gmail.com',  # 作者邮箱
  url = 'https://github.com/lucidrains/se3-transformer-pytorch',  # 项目链接
  keywords = [  # 关键词列表
    'artificial intelligence',
    'attention mechanism',
    'transformers',
    'equivariance',
    'SE3'
  ],
  install_requires=[  # 安装依赖
    'einops>=0.3',
    'filelock',
    'numpy',
    'torch>=1.6'
  ],
  setup_requires=[  # 设置需要的依赖
    'pytest-runner',
  ],
  tests_require=[  # 测试需要的依赖
    'pytest',
    'lie_learn',
    'numpy',
  ],
  classifiers=[  # 分类器
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

.\lucidrains\se3-transformer-pytorch\tests\test_basis.py

# Import the torch library
import torch
# Import get_basis, get_R_tensor, basis_transformation_Q_J from se3_transformer_pytorch.basis
from se3_transformer_pytorch.basis import get_basis, get_R_tensor, basis_transformation_Q_J
# Import irr_repr from se3_transformer_pytorch.irr_repr (needed by test_basis_transformation_Q_J below)
from se3_transformer_pytorch.irr_repr import irr_repr

# 定义测试函数 test_basis
def test_basis():
    # 设置最大阶数为 3
    max_degree = 3
    # 生成一个形状为 (2, 1024, 3) 的随机张量
    x = torch.randn(2, 1024, 3)
    # 调用 get_basis 函数获取基函数
    basis = get_basis(x, max_degree)
    # 断言基函数字典的长度是否为 (max_degree + 1) 的平方
    assert len(basis.keys()) == (max_degree + 1) ** 2, 'correct number of basis kernels'

# 定义测试函数 test_basis_transformation_Q_J
def test_basis_transformation_Q_J():
    # 生成一个形状为 (4, 3) 的随机角度张量
    rand_angles = torch.rand(4, 3)
    # 设置 J, order_out, order_in 的值为 1
    J, order_out, order_in = 1, 1, 1
    # 调用 basis_transformation_Q_J 函数获取变换矩阵 Q_J,并转换为浮点型
    Q_J = basis_transformation_Q_J(J, order_in, order_out).float()
    # 断言对于随机角度中的每个角度 (a, b, c),基函数变换矩阵和不可约表示矩阵的乘积是否与 Q_J 和不可约表示函数的乘积相近
    assert all(torch.allclose(get_R_tensor(order_out, order_in, a, b, c) @ Q_J, Q_J @ irr_repr(J, a, b, c)) for a, b, c in rand_angles)
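    # i.e. the assert above checks the intertwiner property: Q_J maps the degree-J irreducible
    # representation into the (order_out x order_in) tensor product representation, so that
    # (R_out kron R_in) @ Q_J == Q_J @ D_J(R) for every rotation R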

.\lucidrains\se3-transformer-pytorch\tests\test_equivariance.py

import torch
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer
from se3_transformer_pytorch.irr_repr import rot
from se3_transformer_pytorch.utils import torch_default_dtype, fourier_encode

# 测试普通 SE3Transformer 模型
def test_transformer():
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        num_degrees = 2,
        num_neighbors = 4,
        valid_radius = 10
    )

    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    out = model(feats, coors, mask, return_type = 0)
    assert out.shape == (1, 32, 64), 'output must be of the right shape'

# 测试有因果性的 SE3Transformer 模型
def test_causal_se3_transformer():
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        num_degrees = 2,
        num_neighbors = 4,
        valid_radius = 10,
        causal = True
    )

    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    out = model(feats, coors, mask, return_type = 0)
    assert out.shape == (1, 32, 64), 'output must be of the right shape'

# 测试带全局节点的 SE3Transformer 模型
def test_se3_transformer_with_global_nodes():
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        num_degrees = 2,
        num_neighbors = 4,
        valid_radius = 10,
        global_feats_dim = 16
    )

    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    global_feats = torch.randn(1, 2, 16)

    out = model(feats, coors, mask, return_type = 0, global_feats = global_feats)
    assert out.shape == (1, 32, 64), 'output must be of the right shape'

# 测试带单头键值对的 SE3Transformer 模型
def test_one_headed_key_values_se3_transformer_with_global_nodes():
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        num_degrees = 2,
        num_neighbors = 4,
        valid_radius = 10,
        global_feats_dim = 16,
        one_headed_key_values = True
    )

    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    global_feats = torch.randn(1, 2, 16)

    out = model(feats, coors, mask, return_type = 0, global_feats = global_feats)
    assert out.shape == (1, 32, 64), 'output must be of the right shape'

# 测试带边的 SE3Transformer 模型
def test_transformer_with_edges():
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        num_degrees = 2,
        num_neighbors = 4,
        edge_dim = 4,
        num_edge_tokens = 4
    )

    feats = torch.randn(1, 32, 64)
    edges = torch.randint(0, 4, (1, 32))
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    out = model(feats, coors, mask, edges = edges, return_type = 0)
    assert out.shape == (1, 32, 64), 'output must be of the right shape'

# 测试带连续边的 SE3Transformer 模型
def test_transformer_with_continuous_edges():
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        attend_self = True,
        num_degrees = 2,
        output_degrees = 2,
        edge_dim = 34
    )

    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    pairwise_continuous_values = torch.randint(0, 4, (1, 32, 32, 2))

    edges = fourier_encode(
        pairwise_continuous_values,
        num_encodings = 8,
        include_self = True
    )
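    # note: with num_encodings = 8 and include_self = True, each of the 2 pairwise scalar values
    # presumably maps to (2 * 8 + 1) = 17 fourier features, matching edge_dim = 34 in the model above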

    out = model(feats, coors, mask, edges = edges, return_type = 1)
    assert True

# 测试不同输入维度的 SE3Transformer 模型
def test_different_input_dimensions_for_types():
    model = SE3Transformer(
        dim_in = (4, 2),
        dim = 4,
        depth = 1,
        input_degrees = 2,
        num_degrees = 2,
        output_degrees = 2,
        reduce_dim_out = True
    )

    atom_feats  = torch.randn(2, 32, 4, 1)
    coors_feats = torch.randn(2, 32, 2, 3)

    features = {'0': atom_feats, '1': coors_feats}
    coors = torch.randn(2, 32, 3)
    mask  = torch.ones(2, 32).bool()

    refined_coors = coors + model(features, coors, mask, return_type = 1)
    assert True

# 测试等变性
def test_equivariance():
    # 创建一个 SE3Transformer 模型对象,设置参数:维度为64,深度为1,自我关注为True,邻居数量为4,角度数量为2,输出角度数量为2,距离进行傅立叶编码为True
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        attend_self = True,
        num_neighbors = 4,
        num_degrees = 2,
        output_degrees = 2,
        fourier_encode_dist = True
    )

    # 生成一个大小为(1, 32, 64)的随机张量作为特征
    feats = torch.randn(1, 32, 64)
    # 生成一个大小为(1, 32, 3)的随机张量作为坐标
    coors = torch.randn(1, 32, 3)
    # 生成一个大小为(1, 32)的全为True的布尔张量作为掩码
    mask  = torch.ones(1, 32).bool()

    # 生成一个旋转矩阵 R,旋转角度为(15, 0, 45)
    R   = rot(15, 0, 45)
    # 使用模型对特征、经过旋转后的坐标、掩码进行前向传播,返回类型为1
    out1 = model(feats, coors @ R, mask, return_type = 1)
    # 使用模型对特征、原始坐标、掩码进行前向传播,返回类型为1,然后再乘以旋转矩阵 R
    out2 = model(feats, coors, mask, return_type = 1) @ R

    # 计算两个输出之间的最大差异
    diff = (out1 - out2).max()
    # 断言差异小于1e-4,如果不成立则抛出异常 'is not equivariant'
    assert diff < 1e-4, 'is not equivariant'
# 测试具有 EGNN 骨干的等变性
def test_equivariance_with_egnn_backbone():
    # 创建 SE3Transformer 模型
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        attend_self = True,
        num_neighbors = 4,
        num_degrees = 2,
        output_degrees = 2,
        fourier_encode_dist = True,
        use_egnn = True
    )

    # 生成随机特征、坐标和掩码
    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    # 生成旋转矩阵
    R   = rot(15, 0, 45)
    # 使用旋转后的坐标进行模型推理
    out1 = model(feats, coors @ R, mask, return_type = 1)
    # 使用旋转后的特征进行模型推理,然后再旋转输出
    out2 = model(feats, coors, mask, return_type = 1) @ R

    # 计算输出之间的差异
    diff = (out1 - out2).max()
    # 断言输出的差异小于给定阈值
    assert diff < 1e-4, 'is not equivariant'

# 测试旋转
def test_rotary():
    # 创建 SE3Transformer 模型
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        attend_self = True,
        num_neighbors = 4,
        num_degrees = 2,
        output_degrees = 2,
        fourier_encode_dist = True,
        rotary_position = True,
        rotary_rel_dist = True
    )

    # 生成随机特征、坐标和掩码
    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    # 生成旋转矩阵
    R   = rot(15, 0, 45)
    # 使用旋转后的坐标进行模型推理
    out1 = model(feats, coors @ R, mask, return_type = 1)
    # 使用旋转后的特征进行模型推理,然后再旋转输出
    out2 = model(feats, coors, mask, return_type = 1) @ R

    # 计算输出之间的差异
    diff = (out1 - out2).max()
    # 断言输出的差异小于给定阈值
    assert diff < 1e-4, 'is not equivariant'

# 测试等变性线性投影键
def test_equivariance_linear_proj_keys():
    # 创建 SE3Transformer 模型
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        attend_self = True,
        num_neighbors = 4,
        num_degrees = 2,
        output_degrees = 2,
        fourier_encode_dist = True,
        linear_proj_keys = True
    )

    # 生成随机特征、坐标和掩码
    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    # 生成旋转矩阵
    R   = rot(15, 0, 45)
    # 使用旋转后的坐标进行模型推理
    out1 = model(feats, coors @ R, mask, return_type = 1)
    # 使用旋转后的特征进行模型推理,然后再旋转输出
    out2 = model(feats, coors, mask, return_type = 1) @ R

    # 计算输出之间的差异
    diff = (out1 - out2).max()
    # 断言输出的差异小于给定阈值
    assert diff < 1e-4, 'is not equivariant'

# 测试仅稀疏邻居的等变性
@torch_default_dtype(torch.float64)
def test_equivariance_only_sparse_neighbors():
    # 创建 SE3Transformer 模型
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        attend_self = True,
        num_degrees = 2,
        output_degrees = 2,
        num_neighbors = 0,
        attend_sparse_neighbors = True,
        num_adj_degrees = 2,
        adj_dim = 4
    )

    # 生成随机特征、坐标和掩码
    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    # 生成邻接矩阵
    seq = torch.arange(32)
    adj_mat = (seq[:, None] >= (seq[None, :] - 1)) & (seq[:, None] <= (seq[None, :] + 1))

    # 生成旋转矩阵
    R   = rot(15, 0, 45)
    # 使用旋转后的坐标和邻接矩阵进行模型推理
    out1 = model(feats, coors @ R, mask, adj_mat = adj_mat, return_type = 1)
    # 使用旋转后的特征和邻接矩阵进行模型推理,然后再旋转输出
    out2 = model(feats, coors, mask, adj_mat = adj_mat, return_type = 1) @ R

    # 计算输出之间的差异
    diff = (out1 - out2).max()
    # 断言输出的差异小于给定阈值
    assert diff < 1e-4, 'is not equivariant'

# 测试具有可逆网络的等变性
def test_equivariance_with_reversible_network():
    # 创建 SE3Transformer 模型
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        attend_self = True,
        num_neighbors = 4,
        num_degrees = 2,
        output_degrees = 2,
        reversible = True
    )

    # 生成随机特征、坐标和掩码
    feats = torch.randn(1, 32, 64)
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    # 生成旋转矩阵
    R   = rot(15, 0, 45)
    # 使用旋转后的坐标进行模型推理
    out1 = model(feats, coors @ R, mask, return_type = 1)
    # 使用旋转后的特征进行模型推理,然后再旋转输出
    out2 = model(feats, coors, mask, return_type = 1) @ R

    # 计算输出之间的差异
    diff = (out1 - out2).max()
    # 断言输出的差异小于给定阈值
    assert diff < 1e-4, 'is not equivariant'

# 测试具有类型一输入的等变性
def test_equivariance_with_type_one_input():
    # 创建 SE3Transformer 模型
    model = SE3Transformer(
        dim = 64,
        depth = 1,
        attend_self = True,
        num_neighbors = 4,
        num_degrees = 2,
        input_degrees = 2,
        output_degrees = 2
    )

    # 生成随机原子特征和预测坐标
    atom_features = torch.randn(1, 32, 64, 1)
    pred_coors = torch.randn(1, 32, 64, 3)

    # 生成随机坐标和掩码
    coors = torch.randn(1, 32, 3)
    mask  = torch.ones(1, 32).bool()

    # 生成旋转矩阵
    R   = rot(15, 0, 45)
    # 使用旋转后的坐标和预测坐标进行模型推理
    out1 = model({'0': atom_features, '1': pred_coors @ R}, coors @ R, mask, return_type = 1)
    # 使用旋转后的原子特征和预测坐标进行模型推理,然后再旋转输出
    out2 = model({'0': atom_features, '1': pred_coors}, coors, mask, return_type = 1) @ R

    # 计算输出之间的差异
    diff = (out1 - out2).max()
    # 断言输出的差异小于给定阈值
    assert diff < 1e-4, 'is not equivariant'

.\lucidrains\se3-transformer-pytorch\tests\test_irrep_repr.py

# 导入 torch 库
import torch
# 从 se3_transformer_pytorch.spherical_harmonics 模块中导入 clear_spherical_harmonics_cache 函数
from se3_transformer_pytorch.spherical_harmonics import clear_spherical_harmonics_cache
# 从 se3_transformer_pytorch.irr_repr 模块中导入 spherical_harmonics, irr_repr, compose 函数
from se3_transformer_pytorch.irr_repr import spherical_harmonics, irr_repr, compose
# 从 se3_transformer_pytorch.utils 模块中导入 torch_default_dtype 函数
from se3_transformer_pytorch.utils import torch_default_dtype

# 使用 torch.float64 作为默认数据类型
@torch_default_dtype(torch.float64)
# 定义测试函数 test_irr_repr
def test_irr_repr():
    """
    This test tests that
    - irr_repr
    - compose
    - spherical_harmonics
    are compatible

    Y(Z(alpha) Y(beta) Z(gamma) x) = D(alpha, beta, gamma) Y(x)
    with x = Z(a) Y(b) eta
    """
    # 循环遍历阶数范围为 0 到 6
    for order in range(7):
        # 生成两个随机数 a, b
        a, b = torch.rand(2)
        # 生成三个随机数 alpha, beta, gamma
        alpha, beta, gamma = torch.rand(3)

        # 计算 compose(alpha, beta, gamma, a, b, 0) 的结果
        ra, rb, _ = compose(alpha, beta, gamma, a, b, 0)
        # 计算 spherical_harmonics(order, ra, rb) 的结果
        Yrx = spherical_harmonics(order, ra, rb)
        # 清除球谐函数缓存
        clear_spherical_harmonics_cache()

        # 计算 spherical_harmonics(order, a, b) 的结果
        Y = spherical_harmonics(order, a, b)
        # 清除球谐函数缓存
        clear_spherical_harmonics_cache()

        # 计算 irr_repr(order, alpha, beta, gamma) @ Y 的结果
        DrY = irr_repr(order, alpha, beta, gamma) @ Y

        # 计算 (Yrx - DrY).abs().max() 和 Y.abs().max() 的最大值
        d, r = (Yrx - DrY).abs().max(), Y.abs().max()
        # 打印结果
        print(d.item(), r.item())
        # 断言 d < 1e-10 * r,如果不成立则抛出异常
        assert d < 1e-10 * r, d / r

.\lucidrains\se3-transformer-pytorch\tests\test_spherical_harmonics.py

# 导入必要的库
import time
import torch
import numpy as np

# 从 lie_learn 库中导入 spherical_harmonics 函数
from lie_learn.representations.SO3.spherical_harmonics import sh

# 从 se3_transformer_pytorch 库中导入 get_spherical_harmonics_element 和 benchmark 函数
from se3_transformer_pytorch.spherical_harmonics import get_spherical_harmonics_element
from se3_transformer_pytorch.utils import benchmark

# 定义测试 spherical_harmonics 函数
def test_spherical_harmonics():
    # 设置数据类型为 torch.float64
    dtype = torch.float64

    # 生成随机的 theta 和 phi 数据
    theta = 0.1 * torch.randn(32, 1024, 10, dtype=dtype)
    phi = 0.1 * torch.randn(32, 1024, 10, dtype=dtype)

    # 初始化变量
    s0 = s1 = 0
    max_error = -1.

    # 循环遍历 l 和 m 的取值范围
    for l in range(8):
        for m in range(-l, l + 1):
            # 记录开始时间
            start = time.time()

            # 使用 benchmark 函数计算 get_spherical_harmonics_element 函数的运行时间和输出
            diff, y = benchmark(get_spherical_harmonics_element)(l, m, theta, phi)
            # 将 y 转换为 torch.float32 类型
            y = y.type(torch.float32)
            s0 += diff

            # 使用 benchmark 函数计算 sh 函数的运行时间和输出
            diff, z = benchmark(sh)(l, m, theta, phi)
            s1 += diff

            # 计算误差
            error = np.mean(np.abs((y.cpu().numpy() - z) / z))
            max_error = max(max_error, error)
            print(f"l: {l}, m: {m} ", error)

    # 计算时间差异比率
    time_diff_ratio = s0 / s1

    # Assert that the maximum error is less than 1e-4
    assert max_error < 1e-4, 'maximum error must be less than 1e-4'
    # 断言时间差异比率小于 1
    assert time_diff_ratio < 1., 'spherical harmonics must be faster than the one offered by lie_learn'

    # 打印最大误差和时间差异比率
    print(f"Max error: {max_error}")
    print(f"Time diff: {time_diff_ratio}")

Segformer - Pytorch

Implementation of Segformer, Attention + MLP neural network for segmentation, in Pytorch.

Install

$ pip install segformer-pytorch

Usage

For example, MiT-B0

import torch
from segformer_pytorch import Segformer

model = Segformer(
    dims = (32, 64, 160, 256),      # dimensions of each stage
    heads = (1, 2, 5, 8),           # heads of each stage
    ff_expansion = (8, 8, 4, 4),    # feedforward expansion factor of each stage
    reduction_ratio = (8, 4, 2, 1), # reduction ratio of each stage for efficient attention
    num_layers = 2,                 # num layers of each stage
    decoder_dim = 256,              # decoder dimension
    num_classes = 4                 # number of segmentation classes
)

x = torch.randn(1, 3, 256, 256)
pred = model(x) # (1, 4, 64, 64)  # output is (H/4, W/4) map of the number of segmentation classes

Make sure the keywords are at most a tuple of 4, as this repository is hard-coded to give the MiT 4 stages as done in the paper.
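
If you pass a single value for any of these keyword arguments, it is broadcast across all four stages internally (via cast_tuple in the source). A minimal sketch of that, using the same MiT-B0-style settings as above:

import torch
from segformer_pytorch import Segformer

model = Segformer(
    dims = (32, 64, 160, 256),
    heads = (1, 2, 5, 8),
    ff_expansion = 4,               # single value, broadcast to (4, 4, 4, 4)
    reduction_ratio = (8, 4, 2, 1),
    num_layers = 2,                 # single value, broadcast to (2, 2, 2, 2)
    decoder_dim = 256,
    num_classes = 4
)

pred = model(torch.randn(1, 3, 256, 256)) # (1, 4, 64, 64)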

Citations

@misc{xie2021segformer,
    title   = {SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers}, 
    author  = {Enze Xie and Wenhai Wang and Zhiding Yu and Anima Anandkumar and Jose M. Alvarez and Ping Luo},
    year    = {2021},
    eprint  = {2105.15203},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

.\lucidrains\segformer-pytorch\segformer_pytorch\segformer_pytorch.py

# 从 math 模块中导入 sqrt 函数
from math import sqrt
# 从 functools 模块中导入 partial 函数
from functools import partial
# 导入 torch 模块
import torch
# 从 torch 模块中导入 nn、einsum 函数
from torch import nn, einsum
# 从 torch.nn 模块中导入 functional 模块
import torch.nn.functional as F

# 从 einops 模块中导入 rearrange、reduce 函数
from einops import rearrange, reduce
# Import the Rearrange layer from einops.layers.torch
from einops.layers.torch import Rearrange

# 定义一个函数,用于判断变量是否存在
def exists(val):
    return val is not None

# 定义一个函数,用于将变量转换为元组
def cast_tuple(val, depth):
    return val if isinstance(val, tuple) else (val,) * depth

# 定义一个类 DsConv2d,继承自 nn.Module 类
class DsConv2d(nn.Module):
    # 初始化方法
    def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
        super().__init__()
        # 定义一个神经网络序列
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
            nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        )
    # 前向传播方法
    def forward(self, x):
        return self.net(x)

# 定义一个类 LayerNorm,继承自 nn.Module 类
class LayerNorm(nn.Module):
    # 初始化方法
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    # 前向传播方法
    def forward(self, x):
        std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (std + self.eps) * self.g + self.b

# 定义一个类 PreNorm,继承自 nn.Module 类
class PreNorm(nn.Module):
    # 初始化方法
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    # 前向传播方法
    def forward(self, x):
        return self.fn(self.norm(x))

# 定义一个类 EfficientSelfAttention,继承自 nn.Module 类
class EfficientSelfAttention(nn.Module):
    # 初始化方法
    def __init__(
        self,
        *,
        dim,
        heads,
        reduction_ratio
    ):
        super().__init__()
        self.scale = (dim // heads) ** -0.5
        self.heads = heads

        self.to_q = nn.Conv2d(dim, dim, 1, bias = False)
        self.to_kv = nn.Conv2d(dim, dim * 2, reduction_ratio, stride = reduction_ratio, bias = False)
        self.to_out = nn.Conv2d(dim, dim, 1, bias = False)

    # 前向传播方法
    def forward(self, x):
        h, w = x.shape[-2:]
        heads = self.heads

        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = heads), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) (x y) c -> b (h c) x y', h = heads, x = h, y = w)
        return self.to_out(out)
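
# a quick illustrative shape check (not in the original file): with reduction_ratio = 4, keys and
# values are computed on a 4x spatially-downsampled grid, so attention over a 32 x 32 feature map
# uses 8 * 8 = 64 key positions rather than 32 * 32 = 1024
#
#   attn = EfficientSelfAttention(dim = 64, heads = 2, reduction_ratio = 4)
#   attn(torch.randn(1, 64, 32, 32)).shape   # -> torch.Size([1, 64, 32, 32])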

# 定义一个类 MixFeedForward,继承自 nn.Module 类
class MixFeedForward(nn.Module):
    # 初始化方法
    def __init__(
        self,
        *,
        dim,
        expansion_factor
    ):
        super().__init__()
        hidden_dim = dim * expansion_factor
        self.net = nn.Sequential(
            nn.Conv2d(dim, hidden_dim, 1),
            DsConv2d(hidden_dim, hidden_dim, 3, padding = 1),
            nn.GELU(),
            nn.Conv2d(hidden_dim, dim, 1)
        )

    # 前向传播方法
    def forward(self, x):
        return self.net(x)

# 定义一个类 MiT,继承自 nn.Module 类
class MiT(nn.Module):
    # 初始化方法
    def __init__(
        self,
        *,
        channels,
        dims,
        heads,
        ff_expansion,
        reduction_ratio,
        num_layers
    ):
        # 超类初始化
        super().__init__()
        # 定义每个阶段的卷积核大小、步长和填充
        stage_kernel_stride_pad = ((7, 4, 3), (3, 2, 1), (3, 2, 1), (3, 2, 1))

        # 将通道数和维度组合成一个元组
        dims = (channels, *dims)
        # 将维度两两配对
        dim_pairs = list(zip(dims[:-1], dims[1:]))

        # 初始化阶段列表
        self.stages = nn.ModuleList([])

        # 遍历每个阶段的参数
        for (dim_in, dim_out), (kernel, stride, padding), num_layers, ff_expansion, heads, reduction_ratio in zip(dim_pairs, stage_kernel_stride_pad, num_layers, ff_expansion, heads, reduction_ratio):
            # 创建获取重叠补丁的对象
            get_overlap_patches = nn.Unfold(kernel, stride = stride, padding = padding)
            # 创建重叠补丁嵌入的卷积层
            overlap_patch_embed = nn.Conv2d(dim_in * kernel ** 2, dim_out, 1)

            # 初始化层列表
            layers = nn.ModuleList([])

            # 根据层数循环创建自注意力和前馈网络层
            for _ in range(num_layers):
                layers.append(nn.ModuleList([
                    PreNorm(dim_out, EfficientSelfAttention(dim = dim_out, heads = heads, reduction_ratio = reduction_ratio)),
                    PreNorm(dim_out, MixFeedForward(dim = dim_out, expansion_factor = ff_expansion)),
                ]))

            # 将当前阶段的组件添加到阶段列表中
            self.stages.append(nn.ModuleList([
                get_overlap_patches,
                overlap_patch_embed,
                layers
            ]))

    # 前向传播函数
    def forward(
        self,
        x,
        return_layer_outputs = False
    ):
        # 获取输入张量的高度和宽度
        h, w = x.shape[-2:]

        # 初始化存储每个阶段输出的列表
        layer_outputs = []
        # 遍历每个阶段
        for (get_overlap_patches, overlap_embed, layers) in self.stages:
            # 对输入张量进行重叠补丁提取
            x = get_overlap_patches(x)

            # 计算补丁数量和比例
            num_patches = x.shape[-1]
            ratio = int(sqrt((h * w) / num_patches))
            x = rearrange(x, 'b c (h w) -> b c h w', h = h // ratio)

            # 对补丁进行嵌入
            x = overlap_embed(x)
            # 遍历当前阶段的每一层
            for (attn, ff) in layers:
                x = attn(x) + x
                x = ff(x) + x

            # append the output of the current stage to the list
            layer_outputs.append(x)

        # 根据是否返回每个阶段的输出,选择返回值
        ret = x if not return_layer_outputs else layer_outputs
        return ret
class Segformer(nn.Module):
    # 定义 Segformer 类,继承自 nn.Module
    def __init__(
        self,
        *,
        dims = (32, 64, 160, 256),
        heads = (1, 2, 5, 8),
        ff_expansion = (8, 8, 4, 4),
        reduction_ratio = (8, 4, 2, 1),
        num_layers = 2,
        channels = 3,
        decoder_dim = 256,
        num_classes = 4
    ):
        # 初始化函数,接受一系列参数
        super().__init__()
        # 调用父类的初始化函数

        # 将参数转换为长度为4的元组
        dims, heads, ff_expansion, reduction_ratio, num_layers = map(partial(cast_tuple, depth = 4), (dims, heads, ff_expansion, reduction_ratio, num_layers))
        # 断言参数长度为4
        assert all([*map(lambda t: len(t) == 4, (dims, heads, ff_expansion, reduction_ratio, num_layers))]), 'only four stages are allowed, all keyword arguments must be either a single value or a tuple of 4 values'

        # 创建 MiT 模型
        self.mit = MiT(
            channels = channels,
            dims = dims,
            heads = heads,
            ff_expansion = ff_expansion,
            reduction_ratio = reduction_ratio,
            num_layers = num_layers
        )

        # 创建转换到融合层的模块列表
        self.to_fused = nn.ModuleList([nn.Sequential(
            nn.Conv2d(dim, decoder_dim, 1),
            nn.Upsample(scale_factor = 2 ** i)
        ) for i, dim in enumerate(dims)])

        # 创建转换到分割层的模块
        self.to_segmentation = nn.Sequential(
            nn.Conv2d(4 * decoder_dim, decoder_dim, 1),
            nn.Conv2d(decoder_dim, num_classes, 1),
        )

    def forward(self, x):
        # 前向传播函数
        layer_outputs = self.mit(x, return_layer_outputs = True)

        # 对每个输出应用转换到融合层的模块
        fused = [to_fused(output) for output, to_fused in zip(layer_outputs, self.to_fused)]
        # 在通道维度上拼接融合后的输出
        fused = torch.cat(fused, dim = 1)
        # 返回分割层的输出
        return self.to_segmentation(fused)

.\lucidrains\segformer-pytorch\segformer_pytorch\__init__.py

# 从segformer_pytorch模块中导入Segformer类
from segformer_pytorch.segformer_pytorch import Segformer

.\lucidrains\segformer-pytorch\setup.py

# 导入设置工具和查找包的函数
from setuptools import setup, find_packages

# 设置包的元数据
setup(
  name = 'segformer-pytorch', # 包的名称
  packages = find_packages(), # 查找所有包
  version = '0.0.6', # 版本号
  license='MIT', # 许可证
  description = 'Segformer - Pytorch', # 描述
  author = 'Phil Wang', # 作者
  author_email = 'lucidrains@gmail.com', # 作者邮箱
  url = 'https://github.com/lucidrains/segformer-pytorch', # 项目链接
  keywords = [ # 关键词
    'artificial intelligence',
    'deep learning',
    'transformers',
    'image segmentation'
  ],
  install_requires=[ # 安装依赖
    'einops>=0.3',
    'torch>=1.6'
  ],
  classifiers=[ # 分类器
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)

Self-Rewarding Language Model

Implementation of the training framework proposed in Self-Rewarding Language Model, from MetaAI

They really took the title of the DPO paper to heart.

This library also contains an implementation of SPIN, which Teknium of Nous Research has expressed optimism for.

Appreciation

Install

$ pip install self-rewarding-lm-pytorch

Usage

import torch
from torch import Tensor

from self_rewarding_lm_pytorch import (
    SelfRewardingTrainer,
    create_mock_dataset
)

from x_transformers import TransformerWrapper, Decoder

transformer = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 1,
        heads = 8
    )
)

sft_dataset = create_mock_dataset(100, lambda: (torch.randint(0, 256, (256,)), torch.tensor(1)))
prompt_dataset = create_mock_dataset(100, lambda: 'mock prompt')

def decode_tokens(tokens: Tensor) -> str:
    decode_token = lambda token: str(chr(max(32, token)))
    return ''.join(list(map(decode_token, tokens)))

def encode_str(seq_str: str) -> Tensor:
    return Tensor(list(map(ord, seq_str)))

trainer = SelfRewardingTrainer(
    transformer,
    finetune_configs = dict(
        train_sft_dataset = sft_dataset,
        self_reward_prompt_dataset = prompt_dataset,
        dpo_num_train_steps = 1000
    ),
    tokenizer_decode = decode_tokens,
    tokenizer_encode = encode_str,
    accelerate_kwargs = dict(
        cpu = True
    )
)

trainer(overwrite_checkpoints = True)

# checkpoints after each finetuning stage will be saved to ./checkpoints

SPIN can be trained as follows - it can also be added to the fine-tuning pipeline as shown in the final example in the readme.

import torch

from self_rewarding_lm_pytorch import (
    SPINTrainer,
    create_mock_dataset
)

from x_transformers import TransformerWrapper, Decoder

transformer = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

sft_dataset = create_mock_dataset(100, lambda: (torch.randint(0, 256, (256,)), torch.tensor(1)))

spin_trainer = SPINTrainer(
    transformer,
    max_seq_len = 16,
    train_sft_dataset = sft_dataset,
    checkpoint_every = 100,
    spin_kwargs = dict(
        λ = 0.1,
    ),
)

spin_trainer()

Say you want to experiment with your own reward prompt (other than LLM-as-Judge). First you need to import the RewardConfig, then pass it into the trainer as self_reward_prompt_config


# first import

from self_rewarding_lm_pytorch import RewardConfig

# then say you want to try asking the transformer nicely

# reward_regex_template is the string that will be looked for in the LLM response, for parsing out the reward where {{ reward }} is defined as a number

trainer = SelfRewardingTrainer(
    transformer,
    ...,
    self_reward_prompt_config = RewardConfig(
        prompt_template = """
        Pretty please rate the following user prompt and response
        User: {{ prompt }}
        Response: {{ response }}

        Format your score as follows:
        Rating: <rating as integer from 0 - 10>
        """,
        reward_regex_template = """
        Rating: {{ reward }}
        """
    )
)

Finally, if you would like to experiment with arbitrary orders of fine-tuning, you also have that flexibility, by passing in FinetuneConfig instances into finetune_configs as a list

e.g. say you want to carry out research on interleaving SPIN, External Rewarding, and Self-Rewarding

This idea originated from Teknium from a private discord channel.


# import the configs

from self_rewarding_lm_pytorch import (
    SFTConfig,
    SelfRewardDPOConfig,
    ExternalRewardDPOConfig,
    SelfPlayConfig,
)

trainer = SelfRewardingTrainer(
    model,
    finetune_configs = [
        SFTConfig(...),
        SelfPlayConfig(...),
        ExternalRewardDPOConfig(...),
        SelfRewardDPOConfig(...),
        SelfPlayConfig(...),
        SelfRewardDPOConfig(...)
    ],
    ...
)

trainer()

# checkpoints after each finetuning stage will be saved to ./checkpoints

Todo

  • generalize the sampling so that it can progress at different positions in the batch, fix all sampling to be batched. also allow for left padded sequences, in the case some people have transformers with relative positions that allow for that

  • handle eos

  • show an example for using your own reward prompt instead of default llm-as-judge

  • allow for different strategies for sampling the pairs

  • early stopper

    • handle break signal if all done on main process
    • accept eval module, could be either validation loss or something more sophisticated. returns a scalar tensor or single int / float
  • any order of sft, spin, self-rewarding dpo, dpo with external reward model

  • allow for a validation function on the rewards (say reward must be integer, float, in between some range etc)

  • figure out how best to handle different impl of kv cache, for now just do without

  • environment flag that auto-clears all checkpoint folders

Citation

@misc{yuan2024selfrewarding,
    title   = {Self-Rewarding Language Models}, 
    author  = {Weizhe Yuan and Richard Yuanzhe Pang and Kyunghyun Cho and Sainbayar Sukhbaatar and Jing Xu and Jason Weston},
    year    = {2024},
    eprint  = {2401.10020},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@article{Chen2024SelfPlayFC,
    title   = {Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models},
    author  = {Zixiang Chen and Yihe Deng and Huizhuo Yuan and Kaixuan Ji and Quanquan Gu},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2401.01335},
    url     = {https://api.semanticscholar.org/CorpusID:266725672}
}
@article{Rafailov2023DirectPO,
    title   = {Direct Preference Optimization: Your Language Model is Secretly a Reward Model},
    author  = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Stefano Ermon and Christopher D. Manning and Chelsea Finn},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2305.18290},
    url     = {https://api.semanticscholar.org/CorpusID:258959321}
}
@inproceedings{Guo2024DirectLM,
    title   = {Direct Language Model Alignment from Online AI Feedback},
    author  = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Rame and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:267522951}
}

.\lucidrains\self-rewarding-lm-pytorch\self_rewarding_lm_pytorch\dpo.py

# 导入必要的库
import os
from pathlib import Path
from copy import deepcopy
from functools import lru_cache
from collections import namedtuple
from dataclasses import dataclass

# 导入类型提示相关库
from beartype import beartype
from beartype.typing import Optional, Callable, Union, List
from torchtyping import TensorType

# 导入 PyTorch 相关库
import torch
from torch.nn import Module, Dropout
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.optim.lr_scheduler import LinearLR
from torch.utils.data import Dataset, DataLoader
import torch.distributed as dist

# 导入加速库
from accelerate import Accelerator

# 导入 einops 和 einx 库
from einops import rearrange
from einx import get_at

# 导入 numpy 相关库
from numpy.lib.format import open_memmap

# 导入自定义工具函数
from pytorch_custom_utils import (
    get_adam_optimizer,
    OptimizerWithWarmupSchedule
)

# 导入加速相关工具函数
from pytorch_custom_utils.accelerate_utils import (
    model_forward_contexts
)

# 导入自定义工具函数
from pytorch_custom_utils.utils import (
    masked_mean,
    maybe_and_mask
)

# 导入进度条库
from tqdm import tqdm

# 导入 EMA 库
from ema_pytorch import EMA

# 定义辅助函数

# 判断变量是否存在
def exists(v):
    return v is not None

# 如果变量存在则返回变量值,否则返回默认值
def default(v, d):
    return v if exists(v) else d

# 生成循环迭代器
def cycle(dl):
    while True:
        for batch in dl:
            yield batch

# 判断是否处于分布式环境
@lru_cache(maxsize=None)
def is_distributed():
    return dist.is_initialized() and dist.get_world_size() > 1

# 从模型和序列中获取对数概率
def log_prob_from_model_and_seq(model, seq):
    logits = model(seq)
    log_probs = logits.log_softmax(dim=-1)
    return get_at('b n [c], b n -> b n', log_probs, seq)
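
# note: the einx get_at above gathers, for every position, the log-probability assigned to the token
# that actually appears there. an equivalent plain-torch form (illustrative) would be
#   log_probs.gather(-1, seq[..., None]).squeeze(-1)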

# 根据长度和序列生成掩码
def prompt_mask_from_len(lengths, seq):
    seq_len, device = seq.shape[-1], seq.device
    return torch.arange(seq_len, device=device) < rearrange(lengths, '... -> ... 1')

# 设置模型中的 Dropout 概率
def set_dropout_(model: Module, prob: float):
    for module in model.modules():
        if isinstance(module, Dropout):
            module.p = prob

# 使用线性衰减的 Adam 优化器
def adam_optimizer_with_linear_decay(
    model: Module,
    start_learning_rate: float,
    end_learning_rate: float,
    num_decay_steps: int,
    accelerator: Accelerator,
    weight_decay: float,
    adam_kwargs: dict = dict(),
) -> OptimizerWithWarmupSchedule:

    adam = get_adam_optimizer(
        model.parameters(),
        lr=start_learning_rate,
        wd=weight_decay
    )

    scheduler = None
    if start_learning_rate != end_learning_rate:
        scheduler = LinearLR

    return OptimizerWithWarmupSchedule(
        optimizer=adam,
        accelerator=accelerator,
        scheduler=scheduler,
        scheduler_kwargs=dict(
            start_factor=1.,
            end_factor=end_learning_rate / start_learning_rate,
            total_iters=num_decay_steps
        )
    )

# 提前停止

# 定义提前停止返回结果的数据类
@dataclass
class EarlyStopperReturn:
    should_stop: bool
    score: float

# 提前停止类
class EarlyStopper(Module):
    @beartype
    def __init__(
        self,
        model: Module,
        evaluator: Module,
        accelerator: Accelerator,
        calculate_should_stop: Callable[..., bool] = lambda scores: len(scores) > 1 and scores[-1] > scores[-2],
        early_stop_checkpoint_folder: str = './early-stop-checkpoint'
    ):
        super().__init__()
        self.model = model
        self.evaluator = evaluator
        self.accelerator = accelerator

        self.scores: List[Union[int, float]] = []
        self.calculate_should_stop = calculate_should_stop

        self.early_stop_checkpoint_folder = Path(early_stop_checkpoint_folder)
        self.early_stop_checkpoint_folder.mkdir(exist_ok=True, parents=True)

        self.register_buffer('break_signal', torch.tensor(0.))

    # 清空提前停止检查点文件夹
    def clear_early_checkpoint_folder(self):
        for file in self.early_stop_checkpoint_folder.glob('*.pt'):
            os.remove(file)

    # 判断是否为主进程
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # 等待所有进程完成
    def wait(self):
        return self.accelerator.wait_for_everyone()
    # 保存模型的状态到指定路径
    def save(self, path: str, overwrite: bool = False):
        # 等待所有操作完成
        self.wait()

        # 如果是主进程
        if self.is_main:

            # 设置保存路径为早停检查点文件夹下的指定路径
            path = self.early_stop_checkpoint_folder / path

            # 如果文件已存在且不允许覆盖,则抛出异常
            assert not path.exists() or overwrite, f'file already exists'

            # 构建保存的数据包,包含模型的状态字典
            pkg = dict(
                model = self.model.state_dict()
            )

            # 保存数据包到指定路径
            torch.save(pkg, str(path))

    # 前向传播函数,返回早停器的结果
    @torch.no_grad()
    def forward(self) -> EarlyStopperReturn:
        # 设置模型为评估模式
        self.model.eval()

        score = None

        # 如果是主进程
        if self.is_main:

            # 计算模型的评估分数
            score = self.evaluator(self.model)

            # 如果评估分数是张量
            if torch.is_tensor(score):
                # 确保张量元素个数为1
                assert score.numel() == 1
                # 将张量展平为标量
                score = score.flatten().item()

            # 确保评估分数为整数或浮点数
            assert isinstance(score, (int, float))

            # 将评估分数添加到分数列表中
            self.scores.append(score)

            # 计算是否应该停止训练
            should_stop = self.calculate_should_stop(self.scores)

            # 如果应该停止,则设置中断信号为1
            if should_stop:
                self.break_signal.copy_(1.)

        # 处理分布式环境下的早停中断信号
        if is_distributed():
            dist.all_reduce(self.break_signal)
            should_stop = self.break_signal.item() == 1.

        # 处理在评估分数下降之前恢复到检查点的逻辑
        if should_stop:
            # 获取上一个评估分数对应的检查点文件名
            prev_checkpoint_filename = f'model.ckpt.{len(self.scores) - 1}.pt'
            prev_checkpoint_path = self.early_stop_checkpoint_folder / prev_checkpoint_filename
            # 加载上一个检查点的模型状态
            pkg = torch.load(str(prev_checkpoint_path))

            self.model.load_state_dict(pkg['model'])
        else:
            # 生成当前评估分数对应的检查点文件名,并保存当前模型状态
            checkpoint_filename = f'model.ckpt.{len(self.scores)}.pt'
            self.save(checkpoint_filename)

        # return the early stopper result: whether training should stop, plus the evaluation score
        return EarlyStopperReturn(should_stop = self.break_signal.item() == 1, score = score)
# 从两个 memmap numpy 文件中读取数据集

# 数据集包含首选和非首选序列的形状 - (<样本数>, <偏好 (2) - 首选后跟非首选>, <序列长度>)
# 提示长度 (<样本数>,)

class DPODataset(Dataset):
    def __init__(
        self,
        data_folder: str = './',
        preference_seq_memmap_file: str = 'preference_seq.memmap.npy',
        prompt_len_memmap_file: str = 'prompt_len.memmap.npy',
    ):
        self.data_folder = Path(data_folder)
        assert self.data_folder.exists() and self.data_folder.is_dir()

        preference_seq_memmap_path = self.data_folder / preference_seq_memmap_file
        prompt_len_memmap_path = self.data_folder / prompt_len_memmap_file

        assert preference_seq_memmap_path.exists()
        assert prompt_len_memmap_path.exists()

        self.paired_sequences = open_memmap(str(preference_seq_memmap_path), dtype = 'int', mode = 'r')
        self.prompt_len = open_memmap(str(prompt_len_memmap_path), dtype = 'int', mode = 'r')

        self.seq_len = self.paired_sequences.shape[1]
        assert self.paired_sequences.shape[0] == self.prompt_len.shape[0]

    def __len__(self):
        return self.paired_sequences.shape[0]

    def __getitem__(self, idx):
        sequences = self.paired_sequences[idx].copy()
        prompt_lens = self.prompt_len[idx].copy()

        preferred_seq, unpreferred_seq = sequences

        return preferred_seq, unpreferred_seq, prompt_lens
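
# a minimal sketch (not part of the original file) of how the two memmap files this dataset expects
# could be written - shapes follow the comments above: (num samples, 2, seq len) for the paired
# sequences and (num samples,) for the prompt lengths
#
#   from numpy.lib.format import open_memmap
#   seqs = open_memmap('preference_seq.memmap.npy', mode = 'w+', dtype = 'int', shape = (100, 2, 512))
#   lens = open_memmap('prompt_len.memmap.npy', mode = 'w+', dtype = 'int', shape = (100,))
#   seqs[0, 0], seqs[0, 1] = preferred_ids, unpreferred_ids   # hypothetical int token id arrays
#   lens[0] = prompt_length                                   # hypothetical prompt length
#   seqs.flush(); lens.flush()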

# 主类

class DPO(Module):
    def __init__(
        self,
        model: Module,
        *,
        beta = 0.1,
        ref_model_ema_decay = 1.,
        pad_id: Optional[int] = None,
        ema_kwargs: dict = dict()
    ):
        super().__init__()
        self.policy_model = model

        self.ref_model = EMA(
            model,
            beta = ref_model_ema_decay,
            **ema_kwargs
        )

        self.beta = beta
        self.pad_id = pad_id

    def update_reference_model_with_policy(self):
        self.ref_model.copy_params_from_model_to_ema()

    def update_ema(self):
        self.ref_model.update()

    def parameters(self):
        return self.policy_model.parameters()

    @property
    def device(self):
        return next(self.parameters()).device

    @autocast(enabled = False)
    def forward(
        self,
        preferred_seq: TensorType['b', 'n', int],
        unpreferred_seq: TensorType['b', 'n', int],
        prompt_len: TensorType['b', int],
        preferred_seq_mask: Optional[TensorType['b', 'n', bool]] = None,
        unpreferred_seq_mask: Optional[TensorType['b', 'n', bool]] = None
    ):
        # set the policy model to training mode
        self.policy_model.train()

        """
        b - batch
        n - sequence length
        """

        # build prompt masks for the preferred / unpreferred sequences from the prompt lengths
        preferred_prompt_mask = prompt_mask_from_len(prompt_len, preferred_seq)
        unpreferred_prompt_mask = prompt_mask_from_len(prompt_len, unpreferred_seq)

        """
        Following Appendix B in https://arxiv.org/abs/2305.18290
        """

        # if a pad id is given, derive the sequence masks from it
        if exists(self.pad_id):
            # explicit sequence masks must not be passed in at the same time
            assert not exists(preferred_seq_mask)
            assert not exists(unpreferred_seq_mask)
            # mask for the preferred sequence
            preferred_seq_mask = preferred_seq != self.pad_id
            preferred_seq.masked_fill_(~preferred_seq_mask, 0)
            # mask for the unpreferred sequence
            unpreferred_seq_mask = unpreferred_seq != self.pad_id
            unpreferred_seq.masked_fill_(~unpreferred_seq_mask, 0)

        # compute reference model log-probabilities without tracking gradients
        with torch.no_grad():
            # set the reference model to evaluation mode
            self.ref_model.eval()
            # log-probabilities of the preferred and unpreferred sequences under the reference model
            ref_preferred_logprob = log_prob_from_model_and_seq(self.ref_model, preferred_seq)
            ref_unpreferred_logprob = log_prob_from_model_and_seq(self.ref_model, unpreferred_seq)

        # log-probabilities of the preferred and unpreferred sequences under the policy model
        policy_preferred_logprob = log_prob_from_model_and_seq(self.policy_model, preferred_seq)
        policy_unpreferred_logprob = log_prob_from_model_and_seq(self.policy_model, unpreferred_seq)

        # masked means over the non-prompt (and non-padding) positions

        policy_preferred_logprob, ref_preferred_logprob = [masked_mean(seq, maybe_and_mask(preferred_seq_mask, ~preferred_prompt_mask)) for seq in (policy_preferred_logprob, ref_preferred_logprob)]
        policy_unpreferred_logprob, ref_unpreferred_logprob = [masked_mean(seq, maybe_and_mask(unpreferred_seq_mask, ~unpreferred_prompt_mask)) for seq in (policy_unpreferred_logprob, ref_unpreferred_logprob)]

        # DPO loss

        # log-ratios of preferred over unpreferred sequences under the policy and reference models
        policy_logratios = policy_preferred_logprob - policy_unpreferred_logprob
        ref_logratios = ref_preferred_logprob - ref_unpreferred_logprob

        # compute the loss
        losses = -F.logsigmoid(self.beta * (policy_logratios - ref_logratios))

        # return the mean loss
        return losses.mean()
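
# for reference, the quantity computed above is the DPO objective from
# https://arxiv.org/abs/2305.18290, averaged over the batch:
#
#   loss = -log sigmoid( beta * [ (log pi_theta(y_w | x) - log pi_theta(y_l | x))
#                               - (log pi_ref(y_w | x)  - log pi_ref(y_l | x)) ] )
#
# with the per-token log-probabilities here mean-pooled (rather than summed) over the
# non-prompt, non-padding positions of the preferred (y_w) and unpreferred (y_l) sequences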
# trainer class

class DPOTrainer(Module):
    # 初始化方法
    @beartype
    def __init__(
        self,
        dpo: Union[DPO, Module],
        *,
        dataset_generator: Optional[Callable[[], Dataset]] = None,
        accelerator: Optional[Accelerator] = None,
        batch_size: int = 16,
        grad_accum_steps: int = 2,
        num_decay_steps: int = 1000,
        num_train_steps: Optional[int] = None,
        learning_rate: float = 3e-4,
        weight_decay: float = 0.,
        train_dataset: Optional[Dataset] = None,
        valid_dataset: Optional[Dataset] = None,
        start_learning_rate: float = 1e-6,
        end_learning_rate: float = 1e-7,
        early_stopper: Optional[EarlyStopper] = None,
        dropout: float = 0.1,
        check_early_stop_every: int = 200,
        early_stopper_eval_module: Optional[Module] = None,
        adam_kwargs: dict = dict(),
        accelerate_kwargs: dict = dict(),
        dpo_kwargs: dict = dict(
            beta = 0.1,
            ref_model_ema_decay = 1.
        ),
        early_stopper_kwargs: dict = dict()
    ):
        # 调用父类的初始化方法
        super().__init__()

        # 如果 dpo 不是 DPO 类型,则使用 dpo_kwargs 创建 DPO 对象
        if not isinstance(dpo, DPO):
            dpo = DPO(dpo, **dpo_kwargs)

        # 设置 DPO 对象的 dropout
        set_dropout_(dpo, dropout)

        # 如果 accelerator 不存在,则使用 accelerate_kwargs 创建 Accelerator 对象
        if not exists(accelerator):
            accelerator = Accelerator(**accelerate_kwargs)

        # 设置 accelerator
        self.accelerator = accelerator

        # 准备模型
        self.model = accelerator.prepare(dpo)
        self.dropout = dropout

        # 设置数据集生成器
        self.dataset_generator = dataset_generator

        # 设置批量大小和梯度累积步数
        self.batch_size = batch_size
        self.grad_accum_steps = grad_accum_steps

        # 使用 adam_optimizer_with_linear_decay 创建优化器
        self.optimizer = adam_optimizer_with_linear_decay(
            dpo,
            start_learning_rate,
            end_learning_rate,
            num_decay_steps = num_decay_steps,
            accelerator = accelerator,
            weight_decay = weight_decay,
            adam_kwargs = adam_kwargs
        )

        # 如果存在 early_stopper_eval_module,则创建 EarlyStopper 对象
        self.early_stopper = None
        if exists(early_stopper_eval_module):
            self.early_stopper = EarlyStopper(
                dpo.policy_model,
                evaluator = early_stopper_eval_module,
                accelerator = self.accelerator,
                **early_stopper_kwargs
            )

        # 设置检查早停的频率
        self.check_early_stop_every = check_early_stop_every

        # 如果存在 train_dataset,则创建 DataLoader 对象
        self.train_dataloader = None
        if exists(train_dataset):
            self.train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True, drop_last = True)
            self.train_dataloader = accelerator.prepare(self.train_dataloader)

        # 如果存在 valid_dataset,则创建 DataLoader 对象
        self.valid_dataloader = None
        if exists(valid_dataset):
            self.valid_dataloader = DataLoader(valid_dataset, batch_size = batch_size)

        # 初始化步数和训练步数
        self.steps = 0
        self.num_train_steps = num_train_steps

    # 获取未包装的模型
    @property
    def unwrapped_model(self):
        return self.accelerator.unwrap_model(self.model)

    # 判断是否为主进程
    @property
    def is_main(self):
        return self.accelerator.is_main_process

    # 打印信息
    def print(self, *msg):
        self.accelerator.print(*msg)

    # 等待所有进程完成
    def wait(self):
        return self.accelerator.wait_for_everyone()

    # 记录日志
    def log(self, **data):
        self.accelerator.log(data, step = self.steps)

    # 前向传播方法
    def forward(
        self,
        train_self_reward_dataset: Optional[Dataset] = None
        ):
            # 检查是否存在数据集生成器,如果存在则生成训练用的自我奖励数据集
            if exists(self.dataset_generator):
                train_self_reward_dataset = self.dataset_generator()

            # 更新参考模型的策略
            self.model.update_reference_model_with_policy()

            # 如果存在早停器,清除早停检查点文件夹
            if exists(self.early_stopper):
                self.early_stopper.clear_early_checkpoint_folder()

            # 获取训练数据加载器
            train_dataloader = self.train_dataloader

            # 如果训练数据加载器不存在,则创建一个并准备好
            if not exists(train_dataloader):
                assert exists(train_self_reward_dataset)
                train_dataloader = DataLoader(train_self_reward_dataset, batch_size = self.batch_size, drop_last = True, shuffle = True)
                train_dataloader = self.accelerator.prepare(train_dataloader)

            # 创建数据加载器的迭代器
            iter_dl = cycle(train_dataloader)

            # 创建进度条
            pbar = tqdm(desc = 'dpo fine-tuning', total = self.num_train_steps)

            # 设置模型的 dropout
            set_dropout_(self.model, self.dropout)

            # 进入训练循环
            while True:
                self.model.train()

                # 遍历模型前向上下文
                for forward_context in model_forward_contexts(self.accelerator, self.model, self.grad_accum_steps):
                    with forward_context():
                        batch = next(iter_dl)

                        # 计算 DPO 损失
                        dpo_loss = self.model(*batch)
                        self.accelerator.backward(dpo_loss / self.grad_accum_steps)

                # 打印 DPO 损失值
                self.print(f'dpo loss: {dpo_loss.item():.3f}')
                self.log(loss = dpo_loss.item())

                # 执行优化器的步骤
                self.optimizer.step()
                self.optimizer.zero_grad()

                # 等待
                self.wait()

                # 更新指数移动平均模型
                self.unwrapped_model.update_ema()

                # 更新步数并更新进度条
                self.steps += 1
                pbar.update(1)

                # 如果达到训练步数上限,则结束训练
                if exists(self.num_train_steps) and self.steps >= self.num_train_steps:
                    break

                # 检查是否需要早停
                self.wait()

                if not (self.steps % self.check_early_stop_every) and exists(self.early_stopper):

                    # 执行早停逻辑
                    early_stop_return = self.early_stopper()

                    if self.is_main:
                        self.print(f'valid dpo loss: {early_stop_return.score:.3f}')
                        self.log(dpo_valid_score = early_stop_return.score)

                    if early_stop_return.should_stop:
                        self.print('early stopping')
                        break

            # 关闭进度条
            pbar.close()
            self.print('dpo training finished')

.\lucidrains\self-rewarding-lm-pytorch\self_rewarding_lm_pytorch\mocks.py

# 导入 functools 模块中的 wraps 函数
# 导入 typing 模块中的 Type 和 Any 类型
# 导入 torch.utils.data 模块中的 Dataset 类

from functools import wraps
from typing import Type, Any
from torch.utils.data import Dataset

# 创建一个装饰器函数,根据传入的值返回一个装饰器
def always(val):
    # 装饰器函数,接受一个函数作为参数
    def decorator(fn):
        # 内部函数,使用 functools 模块中的 wraps 函数装饰传入的函数
        @wraps(fn)
        # 接受任意参数并根据传入的值返回结果
        def inner(*args, **kwargs):
            # 如果传入的值是可调用的函数,则调用该函数并返回结果
            if callable(val):
                return val()

            # 否则直接返回传入的值
            return val
        return inner
    return decorator

# 创建一个模拟数据集的函数,根据传入的长度和输出值创建一个 Dataset 对象
def create_mock_dataset(
    length: int,
    output: Any
) -> Dataset:

    # 定义一个内部类 MockDataset,继承自 Dataset 类
    class MockDataset(Dataset):
        # 重写 __len__ 方法,返回传入的长度
        def __len__(self):
            return length

        # 重写 __getitem__ 方法,根据索引返回传入的输出值
        def __getitem__(self, idx):
            # 如果传入的输出值是可调用的函数,则调用该函数并返回结果
            if callable(output):
                return output()

            # 否则直接返回传入的输出值
            return output

    # 返回创建的 MockDataset 对象
    return MockDataset()

.\lucidrains\self-rewarding-lm-pytorch\self_rewarding_lm_pytorch\sampling_utils.py

from math import ceil  # ceil, used by top_k below

import torch  # import the PyTorch library
import torch.nn.functional as F  # PyTorch functional API
from torch import Tensor  # PyTorch tensor type
from torch.nn import Module  # PyTorch neural network module base class
from torch.nn.utils.rnn import pad_sequence  # sequence padding utility

from beartype import beartype  # runtime type-checking decorator from beartype
from beartype.typing import Optional, Callable, List, Tuple  # type annotations from beartype

from einops import rearrange  # einops rearrange function

def exists(v):  # 定义函数,判断变量是否存在
    return v is not None

def default(v, d):  # 定义函数,如果变量存在则返回变量,否则返回默认值
    return v if exists(v) else d

# 采样辅助函数

def log(t, eps = 1e-20):  # 定义函数,计算张量的对数
    return torch.log(t.clamp(min = eps))

def gumbel_noise(t):  # 定义函数,生成 Gumbel 噪声
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

def gumbel_sample(t, temperature = 1., dim = -1, keepdim = True, eps = 1e-10):  # 定义函数,进行 Gumbel 采样
    return ((t / max(temperature, eps)) + gumbel_noise(t)).argmax(dim = dim, keepdim = keepdim)

# nucleus

def top_p(logits, thres = 0.9):  # 定义函数,根据 top-p 算法进行筛选
    sorted_logits, sorted_indices = torch.sort(logits, descending = True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)

    sorted_indices_to_remove = cum_probs > thres
    sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value = False)

    sorted_logits[sorted_indices_to_remove] = float('-inf')
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)

# topk

def top_k(logits, frac_num_tokens = 0.1, k: Optional[int] = None):  # 定义函数,根据 top-k 算法进行筛选
    num_tokens = logits.shape[-1]

    k = default(k, ceil(frac_num_tokens * num_tokens))
    k = min(k, num_tokens)

    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
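
# illustrative only: both filters return full-size logits with the discarded entries set to -inf,
# so they compose directly with gumbel_sample, e.g.
#
#   logits = torch.randn(2, 256)
#   tokens = gumbel_sample(top_p(logits, thres = 0.9), temperature = 1.)   # (2, 1) token ids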

# 解码

@torch.no_grad()  # 禁用梯度计算
@beartype  # 使用 beartype 类型检查装饰器
def sample(  # 定义采样函数
    net: Module,  # 神经网络模型
    prompts,  # 输入的提示序列
    seq_len: int,  # 生成序列的长度
    temperature = 1.,  # 温度参数
    filter_fn: Callable = top_p,  # 筛选函数,默认为 top_p
    filter_kwargs: dict = dict(),  # 筛选函数的参数
    pad_id: int = -1,  # padding token id
    eos_id: Optional[int] = None,  # 结束标识符
    output_keep_prompt = False  # 是否保留提示序列
):
    device = next(net.parameters()).device  # 获取神经网络模型的设备
    net.eval()  # 设置神经网络模型为评估模式

    if isinstance(prompts, (tuple, list)):  # 如果提示序列是元组或列表
        prompts = pad_sequence(prompts, batch_first = True, padding_value = pad_id)  # 对提示序列进行填充

    batch, prompts_tensor_len = prompts.shape  # 获取提示序列的批次和长度

    batch_arange = torch.arange(batch, device = device)[..., None]  # 生成批次索引

    prompt_lens = (prompts != pad_id).sum(dim = -1)  # 计算提示序列的长度
    curr_seq_indices = prompt_lens[..., None]  # 当前序列索引

    out = prompts.clone()  # 克隆提示序列

    while (curr_seq_indices < seq_len).any():  # 当当前序列索引小于生成序列长度时循环
        out = F.pad(out, (0, 1), value = pad_id)  # 对输出序列进行填充

        net_input = out.masked_fill(out == pad_id, 0)  # 将填充部分替换为零

        logits = net(net_input)  # 输入神经网络模型,获取输出 logits

        logits = logits[batch_arange, curr_seq_indices]  # 获取当前序列的 logits
        logits = rearrange(logits, 'b 1 d -> b d')  # 重排 logits 的维度

        logits = filter_fn(logits, **filter_kwargs)  # 使用筛选函数对 logits 进行筛选
        sampled_tokens = gumbel_sample(logits, temperature = temperature, dim = -1)  # 使用 Gumbel 采样获取生成的 token

        out[batch_arange, curr_seq_indices] = sampled_tokens  # 将生成的 token 放入输出序列

        curr_seq_indices += 1  # 当前序列索引加一
        curr_seq_indices.clamp_(max = seq_len)  # 限制当前序列索引的最大值为生成序列长度

        if not exists(eos_id):  # 如果不存在结束标识符
            continue  # 继续循环

        is_eos_mask = out == eos_id  # 获取结束标识符的掩码
        all_eos = is_eos_mask.any(dim = -1).all()  # 判断是否所有序列都包含结束标识符

        if all_eos:  # 如果所有序列都包含结束标识符
            break  # 跳出循环

    out = out[:, :seq_len]  # 截取生成的序列长度为指定长度

    if exists(eos_id):  # 如果存在结束标识符
        is_eos_mask = out == eos_id  # 获取结束标识符的掩码
        after_eos_mask = F.pad(is_eos_mask.cumsum(dim = -1) > 0, (1, -1), value = False)  # 获取结束标识符后的掩码
        out = out.masked_fill_(after_eos_mask, pad_id)  # 将结束标识符后的部分替换为填充标识符

    if output_keep_prompt:  # 如果需要保留提示序列
        return out  # 返回生成的序列

    prompt_mask = torch.arange(out.shape[-1], device = device) < prompt_lens[..., None]  # 生成提示序列的掩码

    generated_seq_mask = (out != pad_id) & ~prompt_mask  # mask over the generated (non-pad, non-prompt) tokens
    seq_lens = generated_seq_mask.sum(dim = -1).tolist()  # 计算生成序列的长度

    return out[generated_seq_mask].split(seq_lens)  # 返回拆分后的生成序列
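
# illustrative usage (assumes a hypothetical `net` mapping token ids of shape (batch, seq)
# to logits of shape (batch, seq, num tokens)):
#
#   prompts = [torch.tensor([3, 5, 7]), torch.tensor([11, 13])]
#   generations = sample(net, prompts, seq_len = 32, temperature = 0.7, eos_id = 2)
#   # returns one 1-d tensor of newly generated token ids per prompt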

.\lucidrains\self-rewarding-lm-pytorch\self_rewarding_lm_pytorch\self_rewarding_lm_pytorch.py

# 导入所需的库
import re
import sys
from functools import partial
from random import randrange
from copy import deepcopy
from pathlib import Path
from dataclasses import dataclass, field
from functools import wraps
from textwrap import dedent

# 导入类型提示相关的库
from beartype import beartype
from beartype.typing import Optional, Dict, List, Tuple, Union, Callable
from torchtyping import TensorType

# 导入 PyTorch 相关的库
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList, Dropout
from torch.utils.data import Dataset, ConcatDataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# 导入 NumPy 相关的库
import numpy as np
from numpy.lib.format import open_memmap

# 导入自定义的模块
from self_rewarding_lm_pytorch.dpo import (
    DPO,
    DPODataset,
    DPOTrainer,
    EarlyStopper,
    set_dropout_,
    adam_optimizer_with_linear_decay
)

from self_rewarding_lm_pytorch.spin import (
    SPIN,
    SPINTrainer
)

from einops import rearrange, repeat

from accelerate import Accelerator

from pytorch_custom_utils.utils import pad_or_slice_to

from pytorch_custom_utils.accelerate_utils import (
    model_forward_contexts
)

from self_rewarding_lm_pytorch.sampling_utils import (
    sample,
    top_p,
    top_k
)

from self_rewarding_lm_pytorch.mocks import always

from tqdm import tqdm

# 如果系统是 32 位,则给出警告
if sys.maxsize <= (2 ** 32):
    print('you need to be on 64 bit system to use memmapped files of > 2GB')

# 基本模板引擎
import jinja2
jinja2_env = jinja2.Environment()

# 从 Jinja 模板中查找变量
def find_variables_from_jinja_template(template: str):
    from jinja2 import meta
    ast = jinja2_env.parse(template)
    return meta.find_undeclared_variables(ast)
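
# Illustrative sketch (not part of the original file): the helper simply exposes
# jinja2's undeclared-variable discovery, which the reward configs below rely on
# to validate their templates.

assert find_variables_from_jinja_template('Question: {{ prompt }}\nAnswer: {{ response }}') == {'prompt', 'response'}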

# 辅助函数
# 判断变量是否存在
def exists(v):
    return v is not None

# 如果变量存在则返回变量,否则返回默认值
def default(v, d):
    return v if exists(v) else d

# 返回数组的第一个元素
def first(arr):
    return arr[0]

# 无限循环生成器
def cycle(dl):
    while True:
        for batch in dl:
            yield batch

# 返回输入本身
def identity(t, *args, **kwargs):
    return t

# 根据长度生成掩码
def prompt_mask_from_len(length, seq):
    seq_len, device = seq.shape[-1], seq.device
    return torch.arange(seq_len, device=device) < rearrange(length, '... -> ... 1')
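
# Illustrative sketch (not part of the original file): prompt_mask_from_len marks
# prompt positions, which the SFT trainer below uses to exclude prompt tokens from
# the cross entropy loss. The tensors here are made up.

_example_seq = torch.tensor([[5, 6, 7, 8, 9]])        # prompt = [5, 6], response = [7, 8, 9]
_example_prompt_len = torch.tensor([2])
_example_mask = prompt_mask_from_len(_example_prompt_len, _example_seq)
# -> tensor([[True, True, False, False, False]])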

# 将输入转换为元组
def cast_tuple(t, length=1, validate=False):
    out = t if isinstance(t, tuple) else ((t,) * length)
    assert not validate or len(out) == length
    return out
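
# Illustrative sketch (not part of the original file): cast_tuple either passes a
# tuple through or broadcasts a single value, as done for the self-reward prompt
# datasets in create_default_paper_config further below.

assert cast_tuple('a', 2) == ('a', 'a')
assert cast_tuple(('a', 'b'), 2, validate = True) == ('a', 'b')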

# 转换输入数据类型的装饰器
def cast_input(cast_fn):
    def decorator(fn):
        @wraps(fn)
        def inner(t, *args, **kwargs):
            t = cast_fn(t)
            return fn(t, *args, **kwargs)
        return inner

    return decorator

# 转换输出数据类型的装饰器
def cast_output(cast_fn):
    def decorator(fn):
        @wraps(fn)
        def output(*args, **kwargs):
            out = fn(*args, **kwargs)
            out = cast_fn(out)
            return out
        return output

    return decorator
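
# Illustrative sketch (not part of the original file): the two decorators coerce
# the first positional argument (cast_input) or the return value (cast_output).
# This is how the tokenizer encode / decode callables are wrapped in
# DPODatasetGenerator below; the toy functions here are made up.

@cast_input(lambda t: t.long())
def _double(t):
    return t * 2

@cast_output(lambda t: t.tolist())
def _double_as_list(t):
    return t * 2

assert torch.equal(_double(torch.tensor([1.5, 2.5])), torch.tensor([2, 4]))
assert _double_as_list(torch.tensor([1, 2])) == [2, 4]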

# 常量
# llm-as-judge prompt
# https://openreview.net/forum?id=uccHPGDlao

# 默认的评分模板
DEFAULT_LLM_AS_JUDGE_PROMPT = """
Review the user’s question and the corresponding response using the additive 5-point
scoring system described below. Points are accumulated based on the satisfaction of each
criterion:
- Add 1 point if the response is relevant and provides some information related to
the user’s inquiry, even if it is incomplete or contains some irrelevant content.
- Add another point if the response addresses a substantial portion of the user’s question,
but does not completely resolve the query or provide a direct answer.
- Award a third point if the response answers the basic elements of the user’s question in a
useful way, regardless of whether it seems to have been written by an AI Assistant or if it
has elements typically found in blogs or search results.
- Grant a fourth point if the response is clearly written from an AI Assistant’s perspective,
addressing the user’s question directly and comprehensively, and is well-organized and
helpful, even if there is slight room for improvement in clarity, conciseness or focus.
- Bestow a fifth point for a response that is impeccably tailored to the user’s question
"""
# 定义默认的奖励正则表达式模板
DEFAULT_REWARD_REGEX_TEMPLATE = """
Score: {{ reward }}
"""

# 创建解析奖励函数,根据奖励正则表达式模板
def create_parse_reward_fn(reward_regex_template):
    # ensure the reward regex template contains the "reward" variable
    assert find_variables_from_jinja_template(reward_regex_template) == {'reward'}, 'reward template must include "reward" variable'
    # render the reward regex template, substituting a capture group for the reward value
    reward_regex_str = jinja2_env.from_string(reward_regex_template).render(reward = r"([0-9\.]+)")

    # 解析奖励函数
    def parse_reward_fn(llm_response: str) -> float:
        # 使用正则表达式匹配奖励
        result = re.search(rf"{reward_regex_str}", llm_response)

        # if there is no match, or the match has no capture groups
        if not exists(result) or len(result.groups()) == 0:
            return None

        # if the captured value is not numeric
        if not result.group(1).isnumeric():
            return None

        # return the parsed reward value
        return float(result.group(1))

    return parse_reward_fn
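
# Illustrative sketch (not part of the original file): parsing a hypothetical
# judge response with the parser built from the default regex template above.

_parse_reward = create_parse_reward_fn(DEFAULT_REWARD_REGEX_TEMPLATE)

assert _parse_reward('The answer is relevant and mostly complete.\nScore: 4') == 4.0
assert _parse_reward('no score present in this response') is None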

# 奖励配置
@dataclass
class RewardConfig:
    prompt_template: str
    reward_regex_template: Optional[str] = None
    parse_reward: Optional[Callable[[str], Optional[float]]] = None
    template_fn: Optional[Callable[..., str]] = None
    auto_dedent: bool = True

    # 初始化函数
    def init(self):

        # 可能需要去除缩进
        if self.auto_dedent:
            self.prompt_template = dedent(self.prompt_template)

            # 如果奖励正则表达式模板存在,也需要去除缩进
            if exists(self.reward_regex_template):
                self.reward_regex_template = dedent(self.reward_regex_template)

        # 初始化用于渲染提示和响应模板的函数
        prompt_template = self.prompt_template
        assert find_variables_from_jinja_template(prompt_template) == {'prompt', 'response'}, 'template must include prompt and response templating variables'
        self.template_fn = jinja2_env.from_string(prompt_template).render

        # 如果没有传入解析奖励函数,则根据奖励正则表达式模板创建解析函数
        if not exists(self.parse_reward):
            assert exists(self.reward_regex_template), 'reward_regex_template must be given if parse_reward is not passed in'
            self.parse_reward = create_parse_reward_fn(self.reward_regex_template)

        return self

# 默认奖励提示配置
SELF_REWARD_PROMPT_CONFIG = dict(
    default = RewardConfig(
        prompt_template = DEFAULT_LLM_AS_JUDGE_PROMPT,
        reward_regex_template = DEFAULT_REWARD_REGEX_TEMPLATE
    )
)
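
# Illustrative sketch (not part of the original file): supplying a custom judging
# prompt. The wording is hypothetical; the only hard requirements, enforced in
# RewardConfig.init(), are the {{ prompt }} / {{ response }} variables in the
# prompt template and a {{ reward }} variable in the regex template.

_custom_reward_config = RewardConfig(
    prompt_template = """
    Rate the following answer on a scale from 0 to 5.

    Question: {{ prompt }}
    Answer: {{ response }}
    """,
    reward_regex_template = """
    Rating: {{ reward }}
    """
).init()

_reward_prompt_str = _custom_reward_config.template_fn(prompt = 'What is 2 + 2?', response = '4')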

# default check that a preferred / unpreferred reward pair is valid (the two rewards must differ)
default_is_valid_reward_pair = lambda preferred_reward, unpreferred_reward: (preferred_reward != unpreferred_reward).all()

# 默认的选择配对奖励函数
@beartype
def default_pick_paired_rewards_fn(rewards: Tensor):
    is_nan_mask = torch.isnan(rewards)
    rewards_max, rewards_min = rewards.clone(), rewards.clone()
    rewards_max[is_nan_mask] = -1e6
    rewards_min[is_nan_mask] = 1e6
    return torch.stack((rewards_max.argmax(dim = -1), rewards_min.argmin(dim = -1)))
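
# Illustrative sketch (not part of the original file): given per-candidate
# rewards (NaN marks judge outputs that could not be parsed), the function
# returns the indices of the best and worst candidates, ignoring NaN entries.

_candidate_rewards = torch.tensor([3., float('nan'), 5., 1.])
_pair_indices = default_pick_paired_rewards_fn(_candidate_rewards)
# -> tensor([2, 3]): candidate 2 becomes the preferred response, candidate 3 the unpreferred one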

# SFT训练器类
class SFTTrainer(Module):
    @beartype
    # 初始化模型训练器,设置各种参数
    def __init__(
        self,
        model: Module,
        *,
        accelerator: Accelerator,
        train_dataset: Union[List[Dataset], Dataset],
        valid_dataset: Optional[Dataset] = None,
        batch_size: int = 16,
        grad_accum_steps: int = 2,
        num_epochs: int = 3,
        start_learning_rate: float = 5.5e-6,
        end_learning_rate: float = 1.1e-6,
        learning_rate_num_decay_steps: Optional[int] = None,
        dropout: float = 0.,
        weight_decay: float = 0.,
        ignore_index: int = -1,
        adam_kwargs: dict = dict(),
        valid_every: int = 1
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 设置加速器和模型
        self.accelerator = accelerator
        self.model = model
        self.dropout = dropout

        self.num_epochs = num_epochs
        self.ignore_index = ignore_index

        # 如果训练数据集是列表,则将其合并为一个数据集
        if isinstance(train_dataset, list):
            train_dataset = ConcatDataset(train_dataset)

        # 创建训练数据加载器
        self.train_dataloader = DataLoader(train_dataset, batch_size = batch_size, drop_last = True, shuffle = True)

        # 计算总的训练步数
        self.num_train_steps = len(self.train_dataloader) // grad_accum_steps * num_epochs
        self.grad_accum_steps = grad_accum_steps

        # 准备模型和训练数据加载器
        (
            self.model,
            self.train_dataloader
        ) = self.accelerator.prepare(
            self.model,
            self.train_dataloader
        )

        # 如果学习率衰减步数不存在,则设置为训练数据集长度的一半
        if not exists(learning_rate_num_decay_steps):
            learning_rate_num_decay_steps = len(train_dataset) // 2

        # 创建优化器
        self.optimizer = adam_optimizer_with_linear_decay(
            model,
            start_learning_rate,
            end_learning_rate,
            num_decay_steps = learning_rate_num_decay_steps,
            accelerator = accelerator,
            weight_decay = weight_decay,
            adam_kwargs = adam_kwargs
        )

        self.valid_every = valid_every

        self.valid_dataloader = None
        # 如果验证数据集存在,则创建验证数据加载器
        if exists(valid_dataset):
            self.valid_dataloader = DataLoader(valid_dataset, batch_size = batch_size)

        self.steps = 0

    # 记录日志
    def log(self, **data):
        self.accelerator.log(data, step = self.steps)

    # 等待所有进程完成
    def wait(self):
        return self.accelerator.wait_for_everyone()

    # 计算交叉熵损失
    def get_cross_entropy_loss(
        self,
        seq: TensorType['batch', 'seq', int],
        prompt_len_or_mask: Union[
            TensorType['batch', int],
            TensorType['batch', 'seq', bool]
        ]
    ):
        # 根据输入的 prompt_len_or_mask 类型,生成 prompt_mask
        if prompt_len_or_mask.dtype == torch.long:
            prompt_mask = prompt_mask_from_len(prompt_len_or_mask, seq)
        else:
            prompt_mask = prompt_len_or_mask

        # 将输入序列和标签序列分开
        seq, labels = seq[:, :-1], seq[:, 1:]

        # 根据 prompt_mask 填充标签
        labels.masked_fill_(prompt_mask[:, 1:], self.ignore_index)

        # 获取模型的预测结果
        logits = self.model(seq)

        # 计算交叉熵损失
        return F.cross_entropy(
            rearrange(logits, 'b n l -> b l n'),
            labels,
            ignore_index = self.ignore_index
        )
    # 定义 forward 方法,用于模型的前向传播
    def forward(self):
        
        # 从训练数据加载器中创建一个循环迭代器
        train_dl_iter = cycle(self.train_dataloader)
        
        # 设置模型中的 dropout 层
        set_dropout_(self.model, self.dropout)
        
        # 循环执行训练步骤
        for _ in tqdm(range(self.num_train_steps), desc='sft fine-tuning'):
            self.model.train()
            
            # 遍历模型前向传播上下文
            for forward_context in model_forward_contexts(self.accelerator, self.model, self.grad_accum_steps):
                with forward_context():
                    # 从训练数据加载器中获取下一个序列和提示长度或掩码
                    seq, prompt_len_or_mask = next(train_dl_iter)
                    
                    # 计算交叉熵损失
                    loss = self.get_cross_entropy_loss(seq, prompt_len_or_mask)
                    
                    # 反向传播计算梯度
                    self.accelerator.backward(loss / self.grad_accum_steps)
            
            # 更新优化器参数
            self.optimizer.step()
            self.optimizer.zero_grad()
            
            # 记录损失值
            self.log(loss=loss.item())
            
            # 更新步数
            self.steps += 1
            
            # if a validation dataloader exists and it is time to validate
            if exists(self.valid_dataloader) and not (self.steps % self.valid_every):
                self.wait()
                
                # 如果是主进程
                if self.accelerator.is_main_process:
                    total_valid_loss = 0.
                    total_batches = 0.
                    
                    # 将模型设置为评估模式
                    self.model.eval()
                    
                    # 在无梯度计算的情况下进行验证
                    with torch.no_grad():
                        for seq, prompt_len_or_mask in self.valid_dataloader:
                            batch = seq.shape[0]
                            
                            # 计算验证集的交叉熵损失
                            loss = self.get_cross_entropy_loss(seq, prompt_len_or_mask)
                            
                            total_valid_loss += loss.item() * batch
                            total_batches += batch
                    
                    # 计算验证集的平均损失
                    valid_loss = total_valid_loss / total_batches
                    
                    # 记录验证集损失值
                    self.log(valid_loss=valid_loss)
# 定义一个 DPODatasetGenerator 类,用于生成奖励数据集

class DPODatasetGenerator(Module):
    # 初始化方法,接受以下参数
    @beartype
    def __init__(
        self,
        model: Module,  # 模型对象
        prompt_dataset: Dataset,  # 提示数据集
        num_preference_pairs: int,  # 偏好对数量
        accelerator: Accelerator,  # 加速器对象
        tokenizer_encode: Callable[[str], TensorType['seq', int]],  # 编码器函数
        tokenizer_decode: Callable[[TensorType['seq', int]], str],  # 解码器函数
        self_reward_model: Optional[Module] = None,  # 自我奖励模型,默认为 None
        batch_size: int = 16,  # 批处理大小,默认为 16
        num_candidate_responses: int = 4,  # 候选响应数量,默认为 4
        gen_temperature: float = 0.7,  # 生成温度,默认为 0.7
        gen_filter_fn = top_p,  # 生成过滤函数,默认为 top_p
        gen_filter_kwargs: dict = dict(thres = 0.9),  # 生成过滤函数的参数,默认为 {'thres': 0.9}
        eval_temperature: float = 0.7,  # 评估温度,默认为 0.7
        eval_filter_fn = top_p,  # 评估过滤函数,默认为 top_p
        eval_filter_kwargs: dict = dict(thres = 0.9),  # 评估过滤函数的参数,默认为 {'thres': 0.9}
        num_evals_to_average: int = 3,  # 平均评估次数,默认为 3
        *,
        reward_config: RewardConfig,  # 奖励配置对象
        reward_model: Optional[Module] = None,  # 奖励模型,默认为 None
        data_folder: str = './',  # 数据文件夹,默认为当前目录
        preference_seq_memmap_file: str = 'preference_seq.memmap.npy',  # 偏好序列内存映射文件名,默认为 'preference_seq.memmap.npy'
        prompt_len_memmap_file: str = 'prompt_len.memmap.npy',  # 提示长度内存映射文件名,默认为 'prompt_len.memmap.npy'
        self_reward_memmap_file: str = 'self_reward.memmap.npy',  # 自我奖励内存映射文件名,默认为 'self_reward.memmap.npy'
        preference_max_seq_len: int = 1024,  # 偏好最大序列长度,默认为 1024
        generate_reward_max_seq_len: int = 256,  # 生成奖励最大序列长度,默认为 256
        is_valid_reward: Callable[[float], bool] = lambda *args: True,  # 是否有效奖励的函数,默认为始终返回 True
        is_valid_reward_pair: Optional[Callable[[float, float], bool]] = None,  # 是否有效奖励对的函数,默认为 None
        pick_paired_rewards: Callable[[Tensor], Tensor] = default_pick_paired_rewards_fn,  # 选择配对奖励的函数,默认为 default_pick_paired_rewards_fn
        pad_id: int = -1  # 填充 ID,默认为 -1
    ):
        # 调用父类的初始化方法
        super().__init__()

        # 初始化属性
        self.model = model
        self.num_candidate_responses = num_candidate_responses

        self.self_reward_model = default(self_reward_model, model)
        self.reward_config = reward_config.init()

        self.batch_size = batch_size
        self.prompt_dataset = prompt_dataset
        self.prompt_dataloader = DataLoader(prompt_dataset, batch_size = batch_size, shuffle = True)

        self.gen_filter_fn = gen_filter_fn
        self.gen_filter_kwargs = gen_filter_kwargs
        self.gen_temperature = gen_temperature

        self.eval_filter_fn = eval_filter_fn
        self.eval_filter_kwargs = eval_filter_kwargs
        self.eval_temperature = eval_temperature

        self.tokenizer_encode = cast_output(lambda t: t.long())(tokenizer_encode)
        self.tokenizer_decode = cast_input(lambda t: t.long() if torch.is_tensor(t) else [*map(int, t)])(tokenizer_decode)

        self.num_evals_to_average = num_evals_to_average

        # 逻辑用于采样奖励对并在将其添加到生成的偏好数据集之前进行验证

        self.is_valid_reward = is_valid_reward
        self.is_valid_reward_pair = default(is_valid_reward_pair, lambda *args: True)
        self.pick_paired_rewards = pick_paired_rewards

        # 准备外部奖励模型,如果传入的话

        self.has_external_reward_model = exists(reward_model)
        self.reward_model = reward_model

        # 形状和填充

        self.generate_reward_max_seq_len = generate_reward_max_seq_len

        self.num_preference_pairs = num_preference_pairs

        self.preference_max_seq_len = preference_max_seq_len

        self.pad_id = pad_id

        memmap_shape = (num_preference_pairs, 2, preference_max_seq_len)

        # 保存以便在最后返回 DPO 数据集的实例

        self.dpo_dataset_kwargs = dict(
            data_folder = data_folder,
            preference_seq_memmap_file = preference_seq_memmap_file,
            prompt_len_memmap_file = prompt_len_memmap_file
        )

        # npy 文件的 memmap

        self.data_folder_path = Path(data_folder)
        self.data_folder_path.mkdir(exist_ok = True, parents = True)

        self.preference_seq_memmap_path = self.data_folder_path / preference_seq_memmap_file
        self.prompt_len_memmap_path = self.data_folder_path / prompt_len_memmap_file
        self.self_reward_mmemap_path = self.data_folder_path / self_reward_memmap_file

        self.preference_seq_memmap = open_memmap(str(self.preference_seq_memmap_path), dtype = 'int', mode = 'w+', shape = memmap_shape)
        self.prompt_len_memmap = open_memmap(str(self.prompt_len_memmap_path), dtype = 'int', mode = 'w+', shape = (num_preference_pairs,))
        self.self_reward_memmap_file = open_memmap(str(self.self_reward_mmemap_path), dtype = 'float32', mode = 'w+', shape = (num_preference_pairs, 2))

        self.accelerator = accelerator

    # 返回加速器设备
    @property
    def device(self):
        return self.accelerator.device

    # 生成奖励
    def generate_reward(
        self,
        prompt: str,
        response: str
        ) -> Optional[float]:

        """
        main contribution of the paper is the logic in this function
        in paper, they sample it 3 times and then average
        """

        # 获取模型参数的设备信息
        device = next(self.model.parameters()).device

        # 获取奖励配置中的模板函数和解析奖励函数
        template_fn = self.reward_config.template_fn
        parse_reward = self.reward_config.parse_reward

        # 根据模板函数生成奖励提示字符串,并使用分词器编码为张量
        reward_prompt_str = template_fn(prompt=prompt, response=response)
        reward_prompt = self.tokenizer_encode(reward_prompt_str).to(device)

        # 复制奖励提示张量,重复次数为 self.num_evals_to_average
        reward_prompt = repeat(reward_prompt, 'n -> b n', b=self.num_evals_to_average)

        reward_prompt = reward_prompt.to(device)
        self_reward_model = self.self_reward_model.to(device)

        # 使用自我奖励模型生成奖励响应
        reward_responses = sample(
            self_reward_model,
            prompts=reward_prompt,
            seq_len=self.generate_reward_max_seq_len,
            temperature=self.eval_temperature,
            filter_fn=self.eval_filter_fn,
            filter_kwargs=self.eval_filter_kwargs
        )

        # 将奖励响应转换为字符串列表
        reward_responses_as_str: List[str] = [self.tokenizer_decode(resp[resp != self.pad_id].cpu()) for resp in reward_responses]
        
        # 解析奖励字符串列表,得到奖励值列表
        rewards: List[Optional[float]] = [parse_reward(resp_str) for resp_str in reward_responses_as_str]

        # 过滤掉不存在的奖励值
        rewards = [*filter(exists, rewards)] # for now, just filter out any failed responses

        # 如果奖励值列表为空,则返回 None
        if len(rewards) == 0:
            return None

        # 计算奖励值列表的平均值
        avg_reward = Tensor(rewards).mean().item()
        return avg_reward

    @torch.no_grad()
# 定义了一个类 FinetuneConfig,用于存储微调配置信息
class FinetuneConfig:
    pass

# 导入 partial 函数,用于创建带有默认参数的函数
default_dict = partial(field, default_factory = dict)

# 使用 dataclass 装饰器定义了一个类 SFTConfig,继承自 FinetuneConfig,用于存储自训练配置信息
@dataclass
class SFTConfig(FinetuneConfig):
    train_dataset: Union[Dataset, List[Dataset]]  # 训练数据集,可以是单个数据集或数据集列表
    valid_dataset: Optional[Dataset] = None  # 验证数据集,默认为 None
    dropout: float = 0.1  # dropout 概率,默认为 0.1
    trainer_kwargs: dict = default_dict()  # 训练器参数,默认为一个空字典

# 使用 dataclass 装饰器定义了一个类 SelfRewardDPOConfig,继承自 FinetuneConfig,用于存储自奖励 DPO 配置信息
@dataclass
class SelfRewardDPOConfig(FinetuneConfig):
    prompt_dataset: Dataset  # 提示数据集
    num_generated_preference_pairs: int  # 生成的偏好对数量
    dpo_beta: float = 0.1  # DPO beta 参数,默认为 0.1
    max_seq_len: int = 1024  # 最大序列长度,默认为 1024
    rewarding_model: Optional[Module] = None  # 奖励模型,默认为 None
    self_reward_config_keyname: str = 'default'  # 自奖励配置键名,默认为 'default'
    is_valid_reward: Callable[[float], bool] = lambda reward: reward >= 0  # 验证奖励是否有效的函数,默认为 lambda 函数
    is_valid_reward_pair: Callable[[Tensor, Tensor], bool] = default_is_valid_reward_pair  # 验证奖励对是否有效的函数
    pick_paired_rewards_fn: Callable[[Tensor], Tensor] = default_pick_paired_rewards_fn  # 选择配对奖励的函数
    dropout: float = 0.1  # dropout 概率,默认为 0.1
    early_stopper_eval_module: Optional[Module] = None  # 早停评估模块,默认为 None
    num_train_steps: Optional[int] = None  # number of DPO training steps, defaults to None
    num_candidate_responses: int = 4  # 候选响应数量,默认为 4
    num_sampled_reward_responses: int = 3  # 采样奖励响应数量,默认为 3
    gen_temperature: float = 0.7  # 生成温度,默认为 0.7
    gen_filter_fn: Callable = top_p  # 生成过滤函数,默认为 top_p
    gen_filter_kwargs: dict = default_dict()  # 生成过滤函数参数,默认为一个空字典
    eval_temperature: float = 0.7  # 评估温度,默认为 0.7
    eval_filter_fn: Callable = top_p  # 评估过滤函数,默认为 top_p
    eval_filter_kwargs: dict = default_dict()  # 评估过滤函数参数,默认为一个空字典
    trainer_kwargs: dict = field(default_factory = dict)  # 训练器参数,默认为一个空字典
    reward_generator_kwargs: dict = default_dict()  # reward generator kwargs, defaults to an empty dict

# 使用 dataclass 装饰器定义了一个类 ExternalRewardDPOConfig,继承自 FinetuneConfig,用于存储外部奖励 DPO 配置信息
@dataclass
class ExternalRewardDPOConfig(FinetuneConfig):
    reward_model: Module  # 奖励模型
    dpo_beta: float = 0.1  # DPO beta 参数,默认为 0.1
    max_seq_len: int = 1024  # 最大序列长度,默认为 1024
    gen_temperature: float = 0.7  # 生成温度,默认为 0.7
    gen_filter_fn: Callable = top_p  # 生成过滤函数,默认为 top_p
    gen_filter_kwargs: dict = default_dict()  # 生成过滤函数参数,默认为一个空字典
    dropout: float = 0.1  # dropout 概率,默认为 0.1
    trainer_kwargs: dict = default_dict()  # 训练器参数,默认为一个空字典
    reward_generator_kwargs: dict = default_dict()  # 奖励生成器参数,默认为一个空字典

# 使用 dataclass 装饰器定义了一个类 SelfPlayConfig,继承自 FinetuneConfig,用于存储自对弈配置信息
@dataclass
class SelfPlayConfig(FinetuneConfig):
    train_dataset: Dataset  # 训练数据集
    valid_dataset: Optional[Dataset] = None  # 验证数据集,默认为 None
    max_seq_len: int = 1024  # 最大序列长度,默认为 1024
    spin_λ: float = 0.1  # spin_λ 参数,默认为 0.1
    dropout: float = 0.1  # dropout 概率,默认为 0.1
    temperature: float = 0.7  # 温度,默认为 0.7
    filter_fn: Callable = top_p  # 过滤函数,默认为 top_p
    filter_kwargs: dict = default_dict()  # 过滤函数参数,默认为一个空字典
    trainer_kwargs: dict = default_dict()  # 训练器参数,默认为一个空字典
    spin_kwargs: dict =  default_dict()  # spin 参数,默认为一个空字典

# 定义了一个函数 create_default_paper_config,用于生成默认的论文配置信息
@beartype
def create_default_paper_config(
    *,
    train_sft_dataset: Union[Dataset, List[Dataset]],  # SFT training data, a single dataset or a list of datasets
    self_reward_prompt_dataset: Union[Dataset, Tuple[Dataset, Dataset]],  # 自奖励提示数据集,可以是单个数据集或数据集元组
    valid_sft_dataset: Optional[Dataset] = None,  # 验证 SFT 数据集,默认为 None
    num_generated_preference_pairs = (3964, 6942),  # 生成的偏好对数量,默认为 (3964, 6942)
    early_stopper_eval_module: Optional[Module] = None,  # 早停评估模块,默认为 None
    dpo_num_train_steps: Optional[int] = None,  # DPO 训练步数,默认为 None
    sft_config: dict = dict(),  # SFT 配置信息,默认为一个空字典
    self_reward_config: dict = dict()  # 自奖励配置信息,默认为一个空字典
) -> List[FinetuneConfig]:  # 返回值为 FinetuneConfig 类型的列表

    prompt_dataset_iter1, prompt_dataset_iter2 = cast_tuple(self_reward_prompt_dataset, 2, validate = True)  # 将自奖励提示数据集转换为元组
    num_generated_iter1, num_generated_iter2 = num_generated_preference_pairs  # 解包生成的偏好对数量

    return [
        SFTConfig(
            train_dataset = train_sft_dataset,  # 训练 SFT 数据集
            valid_dataset = valid_sft_dataset,  # 验证 SFT 数据集
            **sft_config  # 其他 SFT 配置信息
        ),
        SelfRewardDPOConfig(
            num_generated_preference_pairs = num_generated_iter1,  # 生成的偏好对数量
            prompt_dataset = prompt_dataset_iter1,  # 提示数据集
            num_train_steps = dpo_num_train_steps,  # DPO 训练步数
            early_stopper_eval_module = early_stopper_eval_module,  # 早停评估模块
            **self_reward_config  # 其他自奖励配置信息
        ),
        SelfRewardDPOConfig(
            num_generated_preference_pairs = num_generated_iter2,  # 生成的偏好对数量
            prompt_dataset = prompt_dataset_iter2,  # 提示数据集
            num_train_steps = dpo_num_train_steps,  # DPO 训练步数
            early_stopper_eval_module = early_stopper_eval_module,  # 早停评估模块
            **self_reward_config  # 其他自奖励配置信息
        )
    ]
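
# Illustrative sketch (not part of the original file): assembling the default
# three-stage schedule (one SFT stage followed by two self-reward DPO iterations).
# The toy dataset is a placeholder; any torch Dataset appropriate for the SFT and
# prompt stages can be passed instead.

class _ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.zeros(16, dtype = torch.long), torch.tensor(4)

_finetune_configs = create_default_paper_config(
    train_sft_dataset = _ToyDataset(),
    self_reward_prompt_dataset = (_ToyDataset(), _ToyDataset()),
    dpo_num_train_steps = 1000
)
# -> [SFTConfig, SelfRewardDPOConfig (3964 pairs), SelfRewardDPOConfig (6942 pairs)]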

# 定义了一个类 SelfRewardingTrainer,继承自 Module,用于自奖励训练
class SelfRewardingTrainer(Module):
    @beartype  # 类型注解装饰器
    # 初始化方法,接受模型、微调配置、编码和解码函数等参数
    def __init__(
        self,
        model: Module,
        *,
        finetune_configs: Union[Dict, List[FinetuneConfig]],
        tokenizer_encode: Callable[[str], TensorType['seq', int]],
        tokenizer_decode: Callable[[TensorType['seq', int]], str],
        self_reward_prompt_config: Union[RewardConfig, Dict[str, RewardConfig]] = SELF_REWARD_PROMPT_CONFIG,
        pad_id: int = -1,
        checkpoints_folder: str = './checkpoints',
        accelerate_kwargs: dict = dict()
    # 获取未加速的模型
    @property
    def unwrapped_model(self):
        return self.accelerator.unwrap_model(self.model)

    # 打印方法,用于输出信息
    def print(self, *msg):
        self.accelerator.print(*msg)

    # 等待方法,等待所有进程完成
    def wait(self):
        return self.accelerator.wait_for_everyone()

    # 保存方法,保存模型参数到指定路径
    def save(self, path: str, overwrite: bool = False):
        self.wait()

        # 如果是主进程
        if self.accelerator.is_main_process:

            # 拼接保存路径
            path = self.checkpoints_folder / path

            # 如果文件已存在且不允许覆盖,则报错
            assert not path.exists() or overwrite, f'file already exists'

            # 封装模型参数并保存
            pkg = dict(
                model = self.unwrapped_model.state_dict()
            )

            torch.save(pkg, str(path))

    # 前向传播方法,用于微调训练
    def forward(
        self,
        overwrite_checkpoints: bool = False
    ):

        # 遍历训练器类型和训练器
        for ind, (trainer_type, trainer) in enumerate(self.trainers):
            finetuning_stage = ind + 1
            trainer()

            # 保存微调阶段的模型参数
            self.save(f'{finetuning_stage}.{trainer_type}.ckpt.pt', overwrite = overwrite_checkpoints)

        # 输出训练完成信息
        self.print(f'self-reward training done')