.\lucidrains\kronecker-attention-pytorch\kronecker_attention_pytorch\__init__.py
# 从 kronecker_attention_pytorch 模块中导入 KroneckerSelfAttention 类
from kronecker_attention_pytorch.kronecker_attention_pytorch import KroneckerSelfAttention
Kronecker Attention Pytorch
Implementation of Kronecker Attention in Pytorch. Results look less than stellar, but if someone found some context where this architecture works well, please post in the issues and let everyone know.
Install
$ pip install kronecker_attention_pytorch
Usage
import torch
from kronecker_attention_pytorch import KroneckerSelfAttention
attn = KroneckerSelfAttention(
chan = 32,
heads = 8,
dim_heads = 64
)
x = torch.randn(1, 32, 256, 512)
attn(x) # (1, 32, 256, 512)
Citations
@article{Gao_2020,
title={Kronecker Attention Networks},
url={http://dx.doi.org/10.1145/3394486.3403065},
journal={Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining},
publisher={ACM},
author={Gao, Hongyang and Wang, Zhengyang and Ji, Shuiwang},
year={2020},
month={Aug}
}
.\lucidrains\kronecker-attention-pytorch\setup.py
# 导入设置工具和查找包的函数
from setuptools import setup, find_packages
# 设置包的元数据
setup(
name = 'kronecker-attention-pytorch', # 包的名称
packages = find_packages(), # 查找所有包
version = '0.0.6', # 版本号
license='MIT', # 许可证
description = 'Kronecker Attention - Pytorch', # 描述
author = 'Phil Wang', # 作者
author_email = 'lucidrains@gmail.com', # 作者邮箱
url = 'https://github.com/lucidrains/kronecker-attention-pytorch', # 项目链接
keywords = [
'artificial intelligence', # 关键词:人工智能
'attention mechanism' # 关键词:注意力机制
],
install_requires=[
'torch', # 安装依赖:torch
'einops>=0.3' # 安装依赖:einops 版本大于等于0.3
],
classifiers=[
'Development Status :: 4 - Beta', # 分类:开发状态为Beta
'Intended Audience :: Developers', # 分类:面向的受众为开发者
'Topic :: Scientific/Engineering :: Artificial Intelligence', # 分类:主题为科学/工程 - 人工智能
'License :: OSI Approved :: MIT License', # 分类:许可证为MIT
'Programming Language :: Python :: 3.6', # 分类:编程语言为Python 3.6
],
)
.\lucidrains\lambda-networks\lambda_networks\lambda_networks.py
import torch
from torch import nn, einsum
from einops import rearrange
# helpers functions
# 检查值是否存在
def exists(val):
return val is not None
# 如果值存在则返回该值,否则返回默认值
def default(val, d):
return val if exists(val) else d
# 计算相对位置编码
def calc_rel_pos(n):
# 生成网格坐标
pos = torch.meshgrid(torch.arange(n), torch.arange(n))
# 重新排列坐标
pos = rearrange(torch.stack(pos), 'n i j -> (i j) n') # [n*n, 2] pos[n] = (i, j)
# 计算相对位置
rel_pos = pos[None, :] - pos[:, None] # [n*n, n*n, 2] rel_pos[n, m] = (rel_i, rel_j)
rel_pos += n - 1 # 将值范围从[-n+1, n-1]转换为[0, 2n-2]
return rel_pos
# lambda layer
class LambdaLayer(nn.Module):
def __init__(
self,
dim,
*,
dim_k,
n = None,
r = None,
heads = 4,
dim_out = None,
dim_u = 1):
super().__init__()
dim_out = default(dim_out, dim)
self.u = dim_u # intra-depth dimension
self.heads = heads
assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
dim_v = dim_out // heads
# 定义卷积层
self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias = False)
self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias = False)
self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias = False)
# 定义归一化层
self.norm_q = nn.BatchNorm2d(dim_k * heads)
self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
# 检查是否存在局部上下文
self.local_contexts = exists(r)
if exists(r):
assert (r % 2) == 1, 'Receptive kernel size should be odd'
self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2))
else:
assert exists(n), 'You must specify the window size (n=h=w)'
rel_lengths = 2 * n - 1
self.rel_pos_emb = nn.Parameter(torch.randn(rel_lengths, rel_lengths, dim_k, dim_u))
self.rel_pos = calc_rel_pos(n)
def forward(self, x):
b, c, hh, ww, u, h = *x.shape, self.u, self.heads
q = self.to_q(x)
k = self.to_k(x)
v = self.to_v(x)
q = self.norm_q(q)
v = self.norm_v(v)
q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h = h)
k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = u)
v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u = u)
k = k.softmax(dim=-1)
λc = einsum('b u k m, b u v m -> b k v', k, v)
Yc = einsum('b h k n, b k v -> b h v n', q, λc)
if self.local_contexts:
v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
λp = self.pos_conv(v)
Yp = einsum('b h k n, b k v n -> b h v n', q, λp.flatten(3))
else:
n, m = self.rel_pos.unbind(dim = -1)
rel_pos_emb = self.rel_pos_emb[n, m]
λp = einsum('n m k u, b u v m -> b n k v', rel_pos_emb, v)
Yp = einsum('b h k n, b n k v -> b h v n', q, λp)
Y = Yc + Yp
out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww)
return out
.\lucidrains\lambda-networks\lambda_networks\tfkeras.py
import tensorflow as tf
from einops.layers.tensorflow import Rearrange
from tensorflow.keras.layers import Conv2D, BatchNormalization, Conv3D, ZeroPadding3D, Softmax, Lambda, Add, Layer
from tensorflow.keras import initializers
from tensorflow import einsum, nn, meshgrid
# 导入所需的库
# helpers functions
def exists(val):
return val is not None
# 检查值是否存在
def default(val, d):
return val if exists(val) else d
# 如果值存在则返回该值,否则返回默认值
def calc_rel_pos(n):
pos = tf.stack(meshgrid(tf.range(n), tf.range(n), indexing = 'ij'))
pos = Rearrange('n i j -> (i j) n')(pos) # 重新排列位置信息
rel_pos = pos[None, :] - pos[:, None] # 计算相对位置
rel_pos += n - 1 # 调整值范围
return rel_pos
# 计算相对位置信息
# lambda layer
class LambdaLayer(Layer):
def __init__(
self,
*,
dim_k,
n = None,
r = None,
heads = 4,
dim_out = None,
dim_u = 1):
super(LambdaLayer, self).__init__()
self.out_dim = dim_out
self.u = dim_u # intra-depth dimension
self.heads = heads
assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
self.dim_v = dim_out // heads
self.dim_k = dim_k
self.heads = heads
self.to_q = Conv2D(self.dim_k * heads, 1, use_bias=False)
self.to_k = Conv2D(self.dim_k * dim_u, 1, use_bias=False)
self.to_v = Conv2D(self.dim_v * dim_u, 1, use_bias=False)
self.norm_q = BatchNormalization()
self.norm_v = BatchNormalization()
self.local_contexts = exists(r)
if exists(r):
assert (r % 2) == 1, 'Receptive kernel size should be odd'
self.pos_conv = Conv3D(dim_k, (1, r, r), padding='same')
else:
assert exists(n), 'You must specify the window length (n = h = w)'
rel_length = 2 * n - 1
self.rel_pos_emb = self.add_weight(name='pos_emb',
shape=(rel_length, rel_length, dim_k, dim_u),
initializer=initializers.random_normal,
trainable=True)
self.rel_pos = calc_rel_pos(n)
# 初始化 LambdaLayer 类
def call(self, x, **kwargs):
b, hh, ww, c, u, h = *x.get_shape().as_list(), self.u, self.heads
q = self.to_q(x)
k = self.to_k(x)
v = self.to_v(x)
q = self.norm_q(q)
v = self.norm_v(v)
q = Rearrange('b hh ww (h k) -> b h k (hh ww)', h=h)(q)
k = Rearrange('b hh ww (u k) -> b u k (hh ww)', u=u)(k)
v = Rearrange('b hh ww (u v) -> b u v (hh ww)', u=u)(v)
k = nn.softmax(k)
Lc = einsum('b u k m, b u v m -> b k v', k, v)
Yc = einsum('b h k n, b k v -> b n h v', q, Lc)
if self.local_contexts:
v = Rearrange('b u v (hh ww) -> b v hh ww u', hh=hh, ww=ww)(v)
Lp = self.pos_conv(v)
Lp = Rearrange('b v h w k -> b v k (h w)')(Lp)
Yp = einsum('b h k n, b v k n -> b n h v', q, Lp)
else:
rel_pos_emb = tf.gather_nd(self.rel_pos_emb, self.rel_pos)
Lp = einsum('n m k u, b u v m -> b n k v', rel_pos_emb, v)
Yp = einsum('b h k n, b n k v -> b n h v', q, Lp)
Y = Yc + Yp
out = Rearrange('b (hh ww) h v -> b hh ww (h v)', hh = hh, ww = ww)(Y)
return out
# 调用 LambdaLayer 类
def compute_output_shape(self, input_shape):
return (*input_shape[:2], self.out_dim)
# ���算输出形状
def get_config(self):
config = {'output_dim': (*self.input_shape[:2], self.out_dim)}
base_config = super(LambdaLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
# 获取配置信息
.\lucidrains\lambda-networks\lambda_networks\__init__.py
# 从 lambda_networks 模块中导入 LambdaLayer 类
from lambda_networks.lambda_networks import LambdaLayer
# 将 LambdaLayer 类赋值给 λLayer 变量
λLayer = LambdaLayer
Lambda Networks - Pytorch
Implementation of λ Networks, a new approach to image recognition that reaches SOTA on ImageNet. The new method utilizes λ layer, which captures interactions by transforming contexts into linear functions, termed lambdas, and applying these linear functions to each input separately.
Install
$ pip install lambda-networks
Usage
Global context
import torch
from lambda_networks import LambdaLayer
layer = LambdaLayer(
dim = 32, # channels going in
dim_out = 32, # channels out
n = 64, # size of the receptive window - max(height, width)
dim_k = 16, # key dimension
heads = 4, # number of heads, for multi-query
dim_u = 1 # 'intra-depth' dimension
)
x = torch.randn(1, 32, 64, 64)
layer(x) # (1, 32, 64, 64)
Localized context
import torch
from lambda_networks import LambdaLayer
layer = LambdaLayer(
dim = 32,
dim_out = 32,
r = 23, # the receptive field for relative positional encoding (23 x 23)
dim_k = 16,
heads = 4,
dim_u = 4
)
x = torch.randn(1, 32, 64, 64)
layer(x) # (1, 32, 64, 64)
For fun, you can also import this as follows
from lambda_networks import λLayer
Tensorflow / Keras version
Shinel94 has added a Keras implementation! It won’t be officially supported in this repository, so either copy / paste the code under ./lambda_networks/tfkeras.py
or make sure to install tensorflow
and keras
before running the following.
import tensorflow as tf
from lambda_networks.tfkeras import LambdaLayer
layer = LambdaLayer(
dim_out = 32,
r = 23,
dim_k = 16,
heads = 4,
dim_u = 1
)
x = tf.random.normal((1, 64, 64, 16)) # channel last format
layer(x) # (1, 64, 64, 32)
Citations
@inproceedings{
anonymous2021lambdanetworks,
title={LambdaNetworks: Modeling long-range Interactions without Attention},
author={Anonymous},
booktitle={Submitted to International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=xTJEN-ggl1b},
note={under review}
}
.\lucidrains\lambda-networks\setup.py
# 导入设置和查找包的函数
from setuptools import setup, find_packages
# 设置包的信息
setup(
name = 'lambda-networks', # 包的名称
packages = find_packages(), # 查找所有包
version = '0.4.0', # 版本号
license='MIT', # 许可证
description = 'Lambda Networks - Pytorch', # 描述
author = 'Phil Wang', # 作者
author_email = 'lucidrains@gmail.com', # 作者邮箱
url = 'https://github.com/lucidrains/lambda-networks', # 项目链接
keywords = [
'artificial intelligence', # 关键词:人工智能
'attention mechanism', # 关键词:注意力机制
'image recognition' # 关键词:图像识别
],
install_requires=[
'torch>=1.6', # 安装所需的 torch 版本
'einops>=0.3' # 安装所需的 einops 版本
],
classifiers=[
'Development Status :: 4 - Beta', # 分类:开发状态为 Beta
'Intended Audience :: Developers', # 分类:面向的受众为开发者
'Topic :: Scientific/Engineering :: Artificial Intelligence', # 分类:主题为科学/工程 - 人工智能
'License :: OSI Approved :: MIT License', # 分类:许可证为 MIT
'Programming Language :: Python :: 3.6', # 分类:编程语言为 Python 3.6
],
)
.\lucidrains\learning-to-expire-pytorch\learning_to_expire_pytorch\learning_to_expire_pytorch.py
# 导入 torch 库
import torch
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 torch.nn.functional 中导入 F 模块
import torch.nn.functional as F
# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 从 collections 模块中导入 namedtuple 类
from collections import namedtuple
# 定义一个命名元组 Memory,包含 mems 和 elapsed_times 两个字段
Memory = namedtuple('Memory', ['mems', 'elapsed_times'])
# 辅助函数
# 判断变量是否存在
def exists(val):
return val is not None
# 如果变量存在则返回其值,否则返回默认值
def default(val, d):
return val if exists(val) else d
# 安全地拼接张量
def safe_cat(tensors, dim = -1):
tensors = list(filter(exists, tensors))
if len(tensors) == 1:
return tensors[0]
return torch.cat(tensors, dim = dim)
# 安全地对张量进行加法操作
def safe_add(tensor, n):
if not exists(tensor):
return None
return tensor + n
# 位置嵌入
# 相对位移函数
def rel_shift(t):
b, h, i, j, device, dtype = *t.shape, t.device, t.dtype
zero_pad = torch.zeros((b, h, i, 1), device = device, dtype = dtype)
concatted = torch.cat([zero_pad, t], dim = -1)
shifted = concatted.view(b, h, j + 1, i)[:, :, 1:]
return shifted.view_as(t)
# 正弦嵌入类
class SinusoidalEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
def forward(self, x):
n, device = x.shape[1], x.device
t = torch.arange(n - 1, -1, -1, device = device).type_as(self.inv_freq)
sinusoid_inp = einsum('i , j -> i j', t, self.inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim = -1)
return emb
# 过期时间跨度逻辑
# 过期时间跨度类
class ExpireSpan(nn.Module):
def __init__(self, dim, max_mem_len, ramp_length):
super().__init__()
self.max_mem_len = max_mem_len
self.ramp_length = ramp_length
self.to_expiration = nn.Linear(dim, 1)
nn.init.constant_(self.to_expiration.bias.data, val = -self.max_mem_len)
def forward(self, mem, time, seq_len):
exps = self.to_expiration(mem).squeeze(-1).sigmoid() * self.max_mem_len
exps = rearrange(exps, 'b j -> b () () j')
t = rearrange(time, 'b j -> b () () j')
r = F.pad(exps - t, (0, seq_len), value = 1.)
mask = torch.clamp((r / self.ramp_length) + 1, min = 0., max = 1.)
return exps, mask
# 类
# 预层归一化类
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)
# 前馈神经网络类
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim * mult),
nn.GELU(),
nn.Linear(dim * mult, dim)
)
def forward(self, x):
return self.net(x)
# 因果注意力类
class CausalAttention(nn.Module):
def __init__(self, dim, heads = 8):
super().__init__()
dim_head = dim // heads
self.heads = heads
self.scale = dim_head ** -0.5
self.to_pos = nn.Linear(dim, dim_head)
self.to_q = nn.Linear(dim, dim)
self.to_kv = nn.Linear(dim, dim * 2)
self.to_out = nn.Linear(dim, dim)
# 定义一个前向传播函数,接受输入 x,位置编码 pos_emb,记忆 mem,默认为 None,过期掩码 expire_mask,默认为 None
def forward(self, x, pos_emb, mem = None, expire_mask = None):
# 获取输入 x 的维度信息:n 为序列长度,h 为头数,scale 为缩放因子,device 为设备信息
n, h, scale, device = x.shape[1], self.heads, self.scale, x.device
# 将输入 x 转换为查询向量 q
q = self.to_q(x)
# 如果存在记忆 mem,则获取其长度,否则记忆长度为 0
mem_len = mem.shape[1] if exists(mem) else 0
# 将记忆 mem 和输入 x 拼接在一起,形成上下文 context
context = safe_cat((mem, x), dim = 1)
# 将上下文 context 转换为键值对 kv,并按键值对拆分为 k 和 v
kv = self.to_kv(context).chunk(2, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, *kv))
# 计算点积注意力得分 dots
dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale
# 计算相对位置贡献
pos = self.to_pos(pos_emb)
pos_dots = einsum('b h i d, j d -> b h i j', q, pos) * scale
pos_dots = rel_shift(pos_dots)
pos_dots = F.pad(pos_dots, (mem_len, 0), value = 0)
dots += pos_dots
# 生成因果掩码
mask = torch.ones(dots.shape[-2:], device = device).triu_(mem_len + 1).bool()
mask = rearrange(mask, 'i j -> () () i j')
dots.masked_fill_(mask, float('-inf'))
del mask
# 计算注意力权重
attn = dots.softmax(dim = -1)
# 如果存在过期掩码,则将注意力权重乘以过期掩码
if exists(expire_mask):
attn = attn * expire_mask
# 计算输出
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
# 定义一个名为 ExpireSpanTransformerXL 的类,继承自 nn.Module
class ExpireSpanTransformerXL(nn.Module):
# 初始化函数,接受多个参数
def __init__(
self,
*,
num_tokens, # 标记的数量
dim, # 向量的维度
depth, # 模型的深度
seq_len, # 序列的长度
heads = 8, # 多头注意力机制的头数,默认为 8
num_memory_blocks = 10, # 记忆块的数量,默认为 10
expire_loss_coef = 1e-6, # 过期损失系数,默认为 1e-6
ramp_length = 128): # 渐变长度,默认为 128
super().__init__() # 调用父类的初始化函数
# 创建一个标记嵌入层,将标记映射到指定维度的向量
self.token_emb = nn.Embedding(num_tokens, dim)
# 创建一个正弦嵌入层,用于添加正弦位置编码
self.sinusoidal_emb = SinusoidalEmbedding(dim)
self.dim = dim # 将维度赋值给类属性
self.depth = depth # 将深度赋值给类属性
self.seq_len = seq_len # 将序列长度赋值给类属性
self.max_mem_len = num_memory_blocks * seq_len # 计算最大记忆长度
self.expire_loss_coef = expire_loss_coef # 将过期损失系数赋值给类属性
self.layers = nn.ModuleList([]) # 创建一个空的模块列表
# 循环创建深度次数的层,并添加到模块列表中
for _ in range(depth):
self.layers.append(nn.ModuleList([
ExpireSpan(dim, self.max_mem_len, ramp_length), # 添加过期跨度模块
PreNorm(dim, CausalAttention(dim, heads = heads)), # 添加预归一化的因果注意力模块
PreNorm(dim, FeedForward(dim)), # 添加预归一化的前馈神经网络模块
]))
self.to_logits = nn.Linear(dim, num_tokens) # 创建一个线性层,将输出维度映射到标记数量
# 定义前向传播函数,接受输入 x 和记忆 memory,默认为 None
def forward(self, x, memory = None):
# 获取输入 x 的形状信息,包括 batch 大小 b,序列长度 n,维度 d,设备信息 device
b, n, d, device = *x.shape, self.dim, x.device
# 对输入 x 进行 token embedding
x = self.token_emb(x)
# 生成位置编码
pos_emb = self.sinusoidal_emb(x)
hidden_states = []
expire_masks_layers = []
# 如果存在记忆,则获取记忆中的 mems 和 elapsed_times,否则初始化为 None
mems_layers = memory.mems if exists(memory) else ((None,) * self.depth)
times_layers = memory.elapsed_times if exists(memory) else ((None,) * self.depth)
# 初始化辅助损失为 0
aux_loss = torch.tensor(0., requires_grad = True)
# 遍历每个层的记忆和时间信息,以及每个层的注意力和前馈网络
for (mem, time, (expire_span, attn, ff)) in zip(mems_layers, times_layers, self.layers):
hidden_states.append(x)
# 计算过期时间和过期掩码
exps, expire_mask = expire_span(mem, time, seq_len = n) if exists(mem) else (None, None)
expire_masks_layers.append(expire_mask)
# 训练模式下,根据时间信息生成遗忘掩码
if self.training and exists(time):
forget_time_thres = torch.randint(0, self.max_mem_len, (b, 1), device = device)
forget_dropout_mask = (time < forget_time_thres).float()
forget_dropout_mask = rearrange(forget_dropout_mask, 'b n -> b () () n')
forget_dropout_mask = F.pad(forget_dropout_mask, (0, n), value = 1.)
expire_mask *= forget_dropout_mask
# 执行注意力和前馈网络操作
x = attn(x, pos_emb = pos_emb, mem = mem, expire_mask = expire_mask) + x
x = ff(x) + x
if exists(exps):
# 计算辅助损失,仅对产生软掩码值的过期进行 L1 辅助损失
expiring_exps_mask = (expire_mask > 0) & (expire_mask < 1.)
expiring_exps = exps.masked_select(expiring_exps_mask[..., :-n])
aux_loss = aux_loss + (expiring_exps / self.seq_len).sum() * self.expire_loss_coef
# 生成最终的 logits
logits = self.to_logits(x)
# 如果序列长度等于 n
if self.seq_len == n:
if exists(expire_mask):
mems_layers_new = []
times_layers_new = []
# 遍���每个层的记忆、时间和过期掩码信息
for mems, times, expire_mask in zip(mems_layers, times_layers, expire_masks_layers):
expire_mask = rearrange(expire_mask, 'b () () i -> b i')
# 丢弃已过期的记忆
expired_exps_mask = (expire_mask <= 0)[..., :-n]
num_to_expire = min(expired_exps_mask.sum(dim = -1)
_, indices = expired_exps_mask.float().topk(k = num_to_expire, dim = -1)
even_expired_exps_mask = torch.zeros_like(expired_exps_mask, device = device).scatter(-1, indices, 1.).bool()
mems = mems.masked_select(~even_expired_exps_mask.unsqueeze(-1))
mems = mems.reshape(b, -1, d)
mems_layers_new.append(mems)
times = times.masked_select(~even_expired_exps_mask)
times = times.reshape(b, -1)
times_layers_new.append(times)
mems_layers = mems_layers_new
times_layers = times_layers_new
# 更新记忆和时间信息
new_memories = map(lambda t: safe_cat(t, dim = 1), list(zip(mems_layers, hidden_states)))
new_memories = map(lambda t: t[:, -self.max_mem_len:].detach(), new_memories)
new_times = torch.arange(n - 1, -1, -1, device = device)
new_times = repeat(new_times, 'n -> b n', b = b)
new_elapsed_times = map(lambda t: safe_cat((safe_add(t, n), new_times), dim = 1), times_layers)
new_elapsed_times = map(lambda t: t[-self.max_mem_len:], new_elapsed_times)
memory = Memory(list(new_memories), list(new_elapsed_times))
# 返回 logits、memory 和辅助损失
return logits, memory, aux_loss
.\lucidrains\learning-to-expire-pytorch\learning_to_expire_pytorch\__init__.py
# 从 learning_to_expire_pytorch.learning_to_expire_pytorch 模块中导入 ExpireSpanTransformerXL 类
from learning_to_expire_pytorch.learning_to_expire_pytorch import ExpireSpanTransformerXL
Learning to Expire - Pytorch (wip)
An implementation of Transformer with Expire-Span, a proposed technique for learning which memories to retain for long-range learning in attention-based networks.
Citations
@inproceedings{
anonymous2021not,
title={Not All Memories are Created Equal: Learning to Expire},
author={Anonymous},
booktitle={Submitted to International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=ZVBtN6B_6i7},
note={under review}
}
.\lucidrains\learning-to-expire-pytorch\setup.py
# 导入设置工具和查找包工具
from setuptools import setup, find_packages
# 设置包的信息
setup(
# 包名
name = 'learning-to-expire-pytorch',
# 查找包,排除 examples 文件夹
packages = find_packages(exclude=['examples']),
# 版本号
version = '0.0.2',
# 许可证
license='MIT',
# 描述
description = 'Learning to Expire - Pytorch',
# 作者
author = 'Phil Wang',
# 作者邮箱
author_email = 'lucidrains@gmail.com',
# 项目链接
url = 'https://github.com/lucidrains/learning-to-expire-pytorch',
# 关键词
keywords = [
'artificial intelligence',
'attention mechanism',
'transformers',
'memory'
],
# 安装依赖
install_requires=[
'torch>=1.6',
'einops>=0.3'
],
# 分类
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\lie-transformer-pytorch\lie_transformer_pytorch\lie_transformer_pytorch.py
# 导入数学库
import math
# 从 functools 库中导入 partial 函数
from functools import partial
# 导入 PyTorch 库
import torch
import torch.nn.functional as F
# 从 torch 库中导入 nn 模块和 einsum 函数
from torch import nn, einsum
# 从 lie_transformer_pytorch.se3 模块中导入 SE3 类
from lie_transformer_pytorch.se3 import SE3
# 从 einops 库中导入 rearrange 和 repeat 函数
from einops import rearrange, repeat
# 从 lie_transformer_pytorch.reversible 模块中导入 SequentialSequence 和 ReversibleSequence 类
# helpers
# 定义函数,判断变量是否存在
def exists(val):
return val is not None
# 定义函数,将变量转换为元组
def cast_tuple(val, depth):
return val if isinstance(val, tuple) else ((val,) * depth)
# 定义函数,返回默认值
def default(val, d):
return val if exists(val) else d
# 定义函数,对张量进行批量索引选择
def batched_index_select(values, indices, dim = 1):
value_dims = values.shape[(dim + 1):]
values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
indices = indices[(..., *((None,) * len(value_dims))]
indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
value_expand_len = len(indices_shape) - (dim + 1)
values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
value_expand_shape = [-1] * len(values.shape)
expand_slice = slice(dim, (dim + value_expand_len))
value_expand_shape[expand_slice] = indices.shape[expand_slice]
values = values.expand(*value_expand_shape)
dim += value_expand_len
return values.gather(dim, indices)
# helper classes
# 定义 Pass 类,用于对输入进行处理
class Pass(nn.Module):
def __init__(self, fn, dim = 1):
super().__init__()
self.fn = fn
self.dim = dim
def forward(self,x):
dim = self.dim
xs = list(x)
xs[dim] = self.fn(xs[dim])
return xs
# 定义 Lambda 类,用于对输入进行处理
class Lambda(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x)
# 定义 GlobalPool 类,用于计算在掩码中所有空间位置(和群元素)上减少的值
class GlobalPool(nn.Module):
def __init__(self, mean = False):
super().__init__()
self.mean = mean
def forward(self, x):
coords, vals, mask = x
if not exists(mask):
return val.mean(dim = 1)
masked_vals = vals.masked_fill_(~mask[..., None], 0.)
summed = masked_vals.sum(dim = 1)
if not self.mean:
return summed
count = mask.sum(-1).unsqueeze(-1)
return summed / count
# subsampling code
# 定义 FPSindices 函数,用于根据距离矩阵和掩码进行下采样
def FPSindices(dists, frac, mask):
""" inputs: pairwise distances DISTS (bs,n,n), downsample_frac (float), valid atom mask (bs,n)
outputs: chosen_indices (bs,m) """
m = int(round(frac * dists.shape[1]))
bs, n, device = *dists.shape[:2], dists.device
dd_kwargs = {'device': device, 'dtype': torch.long}
B = torch.arange(bs, **dd_kwargs)
chosen_indices = torch.zeros(bs, m, **dd_kwargs)
distances = torch.ones(bs, n, device=device) * 1e8
a = torch.randint(0, n, (bs,), **dd_kwargs) # choose random start
idx = a % mask.sum(-1) + torch.cat([torch.zeros(1, **dd_kwargs), torch.cumsum(mask.sum(-1), dim=0)[:-1]], dim=0)
farthest = torch.where(mask)[1][idx]
for i in range(m):
chosen_indices[:, i] = farthest # add point that is farthest to chosen
dist = dists[B, farthest].masked_fill(~mask, -100) # (bs,n) compute distance from new point to all others
closer = dist < distances # if dist from new point is smaller than chosen points so far
distances[closer] = dist[closer] # update the chosen set's distance to all other points
farthest = torch.max(distances, -1)[1] # select the point that is farthest from the set
return chosen_indices
# 定义 FPSsubsample 类,用于进行 FPS 下采样
class FPSsubsample(nn.Module):
def __init__(self, ds_frac, cache = False, group = None):
super().__init__()
self.ds_frac = ds_frac
self.cache = cache
self.cached_indices = None
self.group = default(group, SE3())
# 获取查询索引,根据是否启用缓存和缓存文件是否存在来决定是否重新计算
def get_query_indices(self, abq_pairs, mask):
# 如果启用缓存并且缓存文件存在,则直接返回缓存的查询索引
if self.cache and exists(self.cached_indices):
return self.cached_indices
# 定义距离函数,如果存在分组则使用分组的距离函数,否则使用默认的 L2 范数
dist = self.group.distance if self.group else lambda ab: ab.norm(dim=-1)
# 计算 FPS 索引,根据数据集的分数和掩码值,返回索引值,并且将其从计算图中分离
value = FPSindices(dist(abq_pairs), self.ds_frac, mask).detach()
# 如果启用缓存,则将计算得到的索引值缓存起来
if self.cache:
self.cached_indices = value
# 返回计算得到的索引值
return value
# 前向传播函数,根据输入数据进行处理并返回结果
def forward(self, inp, withquery=False):
# 解包输入数据
abq_pairs, vals, mask, edges = inp
# 获取设备信息
device = vals.device
# 如果数据子采样比例不为1
if self.ds_frac != 1:
# 获取查询索引
query_idx = self.get_query_indices(abq_pairs, mask)
# 创建索引 B,用于索引操作
B = torch.arange(query_idx.shape[0], device=device).long()[:, None]
# 根据查询索引对 abq_pairs 进行子采样
subsampled_abq_pairs = abq_pairs[B, query_idx][B, :, query_idx]
# 根据查询索引对 vals 进行子采样
subsampled_values = batched_index_select(vals, query_idx, dim=1)
# 根据查询索引对 mask 进行子采样
subsampled_mask = batched_index_select(mask, query_idx, dim=1)
# 如果存在边信息,则根据查询索引对 edges 进行子采样
subsampled_edges = edges[B, query_idx][B, :, query_idx] if exists(edges) else None
else:
# 如果数据子采样比例为1,则不进行子采样操作
subsampled_abq_pairs = abq_pairs
subsampled_values = vals
subsampled_mask = mask
subsampled_edges = edges
query_idx = None
# 将子采样后的数据组合成元组
ret = (
subsampled_abq_pairs,
subsampled_values,
subsampled_mask,
subsampled_edges
)
# 如果需要查询索引信息,则将查询索引信息添加到返回结果中
if withquery:
ret = (*ret, query_idx)
# 返回处理后的结果
return ret
# 定义一个自注意力机制的类 LieSelfAttention
class LieSelfAttention(nn.Module):
def __init__(
self,
dim,
edge_dim = None,
group = None,
mc_samples = 32,
ds_frac = 1,
fill = 1 / 3,
dim_head = 64,
heads = 8,
cache = False
):
super().__init__()
self.dim = dim
# 设置用于估计卷积的样本数量
self.mc_samples = mc_samples
# 设置 LieConv 的等变性群
self.group = default(group, SE3())
# 注册缓冲区变量 r,用于本地邻域半径,由 fill 设置
self.register_buffer('r',torch.tensor(2.))
# 设置平均输入进入本地邻域的分数,决定 r
self.fill_frac = min(fill, 1.)
# 创建 FPSsubsample 对象,用于下采样
self.subsample = FPSsubsample(ds_frac, cache = cache, group = self.group)
# 内部系数,用于更新 r
self.coeff = .5
# 用于记录平均填充分数,仅用于日志记录
self.fill_frac_ema = fill
# 注意力相关参数
inner_dim = dim_head * heads
self.heads = heads
# 线性变换,用于计算查询、键、值和输出
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_k = nn.Linear(dim, inner_dim, bias = False)
self.to_v = nn.Linear(dim, inner_dim, bias = False)
self.to_out = nn.Linear(inner_dim, dim)
edge_dim = default(edge_dim, 0)
edge_dim_in = self.group.lie_dim + edge_dim
# 局部注意力 MLP
self.loc_attn_mlp = nn.Sequential(
nn.Linear(edge_dim_in, edge_dim_in * 4),
nn.ReLU(),
nn.Linear(edge_dim_in * 4, 1),
)
# 提取邻域信息
def extract_neighborhood(self, inp, query_indices):
""" inputs: [pairs_abq (bs,n,n,d), inp_vals (bs,n,c), mask (bs,n), query_indices (bs,m)]
outputs: [neighbor_abq (bs,m,mc_samples,d), neighbor_vals (bs,m,mc_samples,c)]"""
# 从输入中获取数据
pairs_abq, inp_vals, mask, edges = inp
device = inp_vals.device
# 根据查询索引对 pairs_abq、inp_vals、mask 进行下采样
if exists(query_indices):
abq_at_query = batched_index_select(pairs_abq, query_indices, dim = 1)
mask_at_query = batched_index_select(mask, query_indices, dim = 1)
edges_at_query = batched_index_select(edges, query_indices, dim = 1) if exists(edges) else None
else:
abq_at_query = pairs_abq
mask_at_query = mask
edges_at_query = edges
mask_at_query = mask_at_query[..., None]
vals_at_query = inp_vals
dists = self.group.distance(abq_at_query)
mask_value = torch.finfo(dists.dtype).max
dists = dists.masked_fill(mask[:,None,:], mask_value)
k = min(self.mc_samples, inp_vals.shape[1])
# 从距离球中采样
bs, m, n = dists.shape
within_ball = (dists < self.r) & mask[:,None,:] & mask_at_query
noise = torch.zeros((bs, m, n), device = device).uniform_(0, 1)
valid_within_ball, nbhd_idx = torch.topk(within_ball + noise, k, dim=-1, sorted=False)
valid_within_ball = (valid_within_ball > 1)
# 获取邻域位置的 abq_pairs、values 和 mask
nbhd_abq = batched_index_select(abq_at_query, nbhd_idx, dim = 2)
nbhd_vals = batched_index_select(vals_at_query, nbhd_idx, dim = 1)
nbhd_mask = batched_index_select(mask, nbhd_idx, dim = 1)
nbhd_edges = batched_index_select(edges_at_query, nbhd_idx, dim = 2) if exists(edges) else None
# 如果处于训练阶段,���新球半径以匹配 fill_frac
if self.training:
navg = (within_ball.float()).sum(-1).sum() / mask_at_query.sum()
avg_fill = (navg / mask.sum(-1).float().mean()).cpu().item()
self.r += self.coeff * (self.fill_frac - avg_fill)
self.fill_frac_ema += .1 * (avg_fill-self.fill_frac_ema)
nbhd_mask &= valid_within_ball.bool()
return nbhd_abq, nbhd_vals, nbhd_mask, nbhd_edges, nbhd_idx
# 定义前向传播函数,接收输入数据
def forward(self, inp):
"""inputs: [pairs_abq (bs,n,n,d)], [inp_vals (bs,n,ci)]), [query_indices (bs,m)]
outputs [subsampled_abq (bs,m,m,d)], [convolved_vals (bs,m,co)]"""
# 从输入数据中抽取子样本,包括子样本的abq、值、掩码、边缘和查询索引
sub_abq, sub_vals, sub_mask, sub_edges, query_indices = self.subsample(inp, withquery = True)
# 从输入数据中提取邻域,包括邻域的abq、值、掩码、边缘和邻域索引
nbhd_abq, nbhd_vals, nbhd_mask, nbhd_edges, nbhd_indices = self.extract_neighborhood(inp, query_indices)
# 获取头数、批次大小、节点数、特征维度和设备信息
h, b, n, d, device = self.heads, *sub_vals.shape, sub_vals.device
# 将子样本的值转换为查询、键和值
q, k, v = self.to_q(sub_vals), self.to_k(nbhd_vals), self.to_v(nbhd_vals)
# 重排查询、键和值的维度
q = rearrange(q, 'b n (h d) -> b h n d', h = h)
k, v = map(lambda t: rearrange(t, 'b n m (h d) -> b h n m d', h = h), (k, v))
# 计算注意力相似度
sim = einsum('b h i d, b h i j d -> b h i j', q, k) * (q.shape[-1] ** -0.5)
# 更新边缘信息
edges = nbhd_abq
if exists(nbhd_edges):
edges = torch.cat((nbhd_abq, nbhd_edges), dim = -1)
# 通过位置注意力MLP更新位置注意力
loc_attn = self.loc_attn_mlp(edges)
loc_attn = rearrange(loc_attn, 'b i j () -> b () i j')
sim = sim + loc_attn
# 创建掩码值
mask_value = -torch.finfo(sim.dtype).max
# 使用掩码值对相似度矩阵进行掩码
sim.masked_fill_(~rearrange(nbhd_mask, 'b n m -> b () n m'), mask_value)
# 计算注意力权重
attn = sim.softmax(dim = -1)
# 计算输出值
out = einsum('b h i j, b h i j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)', h = h)
# 将输出值转换为输出维度
combined = self.to_out(out)
# 返回子样本的abq、组合值、子样本掩码和子样本边缘
return sub_abq, combined, sub_mask, sub_edges
class LieSelfAttentionWrapper(nn.Module):
# 自注意力机制的包装器类
def __init__(self, dim, attn):
super().__init__()
self.dim = dim
self.attn = attn
self.net = nn.Sequential(
Pass(nn.LayerNorm(dim)), # 添加层归一化
self.attn
)
def forward(self, inp):
sub_coords, sub_values, mask, edges = self.attn.subsample(inp)
new_coords, new_values, mask, edges = self.net(inp)
new_values[..., :self.dim] += sub_values
return new_coords, new_values, mask, edges
class FeedForward(nn.Module):
# 前馈神经网络类
def __init__(self, dim, mult = 4):
super().__init__()
self.dim = dim
self.net = nn.Sequential(
Pass(nn.LayerNorm(dim)), # 添加层归一化
Pass(nn.Linear(dim, mult * dim)), # 线性变换
Pass(nn.GELU()), # GELU激活函数
Pass(nn.Linear(mult * dim, dim)), # 线性变换
)
def forward(self,inp):
sub_coords, sub_values, mask, edges = inp
new_coords, new_values, mask, edges = self.net(inp)
new_values = new_values + sub_values
return new_coords, new_values, mask, edges
# transformer class
class LieTransformer(nn.Module):
"""
[Fill] specifies the fraction of the input which is included in local neighborhood.
(can be array to specify a different value for each layer)
[nbhd] number of samples to use for Monte Carlo estimation (p)
[dim] number of input channels: 1 for MNIST, 3 for RGB images, other for non images
[ds_frac] total downsampling to perform throughout the layers of the net. In (0,1)
[num_layers] number of BottleNeck Block layers in the network
[k] channel width for the network. Can be int (same for all) or array to specify individually.
[liftsamples] number of samples to use in lifting. 1 for all groups with trivial stabilizer. Otherwise 2+
[Group] Chosen group to be equivariant to.
"""
def __init__(
self,
dim,
num_tokens = None,
num_edge_types = None,
edge_dim = None,
heads = 8,
dim_head = 64,
depth = 2,
ds_frac = 1.,
dim_out = None,
k = 1536,
nbhd = 128,
mean = True,
per_point = True,
liftsamples = 4,
fill = 1 / 4,
cache = False,
reversible = False,
**kwargs
):
super().__init__()
assert not (ds_frac < 1 and reversible), 'must not downsample if network is reversible'
dim_out = default(dim_out, dim)
self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None
self.edge_emb = nn.Embedding(num_edge_types, edge_dim) if exists(num_edge_types) else None
group = SE3()
self.group = group
self.liftsamples = liftsamples
layers_fill = cast_tuple(fill, depth)
layers = nn.ModuleList([])
for _, layer_fill in zip(range(depth), layers_fill):
layers.append(nn.ModuleList([
LieSelfAttentionWrapper(dim, LieSelfAttention(dim, heads = heads, dim_head = dim_head, edge_dim = edge_dim, mc_samples = nbhd, ds_frac = ds_frac, group = group, fill = fill, cache = cache,**kwargs)),
FeedForward(dim)
]))
execute_class = ReversibleSequence if reversible else SequentialSequence
self.net = execute_class(layers)
self.to_logits = nn.Sequential(
Pass(nn.LayerNorm(dim)), # 添加层归一化
Pass(nn.Linear(dim, dim_out)) # 线性变换
)
self.pool = GlobalPool(mean = mean) # 全局池化
# 定义一个前向传播函数,接受特征、坐标、边缘、掩码等参数,并返回池化结果
def forward(self, feats, coors, edges = None, mask = None, return_pooled = False):
# 获取批次大小、节点数等信息
b, n, *_ = feats.shape
# 如果存在 token_emb 属性,则对特征进行处理
if exists(self.token_emb):
feats = self.token_emb(feats)
# 如果存在 edge_emb 属性,则对边缘进行处理
if exists(self.edge_emb):
# 确保 edges 参数存在
assert exists(edges), 'edges must be passed in on forward'
# 确保 edges 的形状符合要求
assert edges.shape[1] == edges.shape[2] and edges.shape[1] == n, f'edges must be of the shape ({b}, {n}, {n})'
edges = self.edge_emb(edges)
# 将坐标、特征、掩码、边缘等参数组合成元组
inps = (coors, feats, mask, edges)
# 使用 group 属性对输入进行变换
lifted_x = self.group.lift(inps, self.liftsamples)
# 将变换后的输入传入网络进行计算
out = self.net(lifted_x)
# 将输出结果转换为 logits
out = self.to_logits(out)
# 如果不需要返回池化结果,则直接返回特征
if not return_pooled:
features = out[1]
return features
# 返回池化结果
return self.pool(out)
.\lucidrains\lie-transformer-pytorch\lie_transformer_pytorch\reversible.py
# 导入 torch 库
import torch
# 导入 torch 中的神经网络模块
import torch.nn as nn
# 从 torch.autograd.function 中导入 Function 类
from torch.autograd.function import Function
# 从 torch.utils.checkpoint 中导入 get_device_states 和 set_device_states 函数
# 辅助函数
# 对元组中指定维度的元素求和
def sum_tuple(x, y, dim = 1):
x = list(x)
x[dim] += y[dim]
return tuple(x)
# 对元组中指定维度的元素求差
def subtract_tuple(x, y, dim = 1):
x = list(x)
x[dim] -= y[dim]
return tuple(x)
# 设置元组中指定维度的值
def set_tuple(x, dim, value):
x = list(x).copy()
x[dim] = value
return tuple(x)
# 对元组中指定维度的元素应用函数
def map_tuple(fn, x, dim = 1):
x = list(x)
x[dim] = fn(x[dim])
return tuple(x)
# 对元组中指定维度的元素进行分块
def chunk_tuple(fn, x, dim = 1):
x = list(x)
value = x[dim]
chunks = fn(value)
return tuple(map(lambda t: set_tuple(x, 1, t), chunks))
# 将两个元组在指定维度进行拼接
def cat_tuple(x, y, dim = 1, cat_dim = -1):
x = list(x)
y = list(y)
x[dim] = torch.cat((x[dim], y[dim]), dim = cat_dim)
return tuple(x)
# 删除元组中的元素
def del_tuple(x):
for el in x:
if el is not None:
del el
# 根据 https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html 中的示例,实现保存和设置随机数生成器状态的类
class Deterministic(nn.Module):
def __init__(self, net):
super().__init__()
self.net = net
self.cpu_state = None
self.cuda_in_fwd = None
self.gpu_devices = None
self.gpu_states = None
def record_rng(self, *args):
self.cpu_state = torch.get_rng_state()
if torch.cuda._initialized:
self.cuda_in_fwd = True
self.gpu_devices, self.gpu_states = get_device_states(*args)
def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
if record_rng:
self.record_rng(*args)
if not set_rng:
return self.net(*args, **kwargs)
rng_devices = []
if self.cuda_in_fwd:
rng_devices = self.gpu_devices
with torch.random.fork_rng(devices=rng_devices, enabled=True):
torch.set_rng_state(self.cpu_state)
if self.cuda_in_fwd:
set_device_states(self.gpu_devices, self.gpu_states)
return self.net(*args, **kwargs)
# 受 https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py 启发,实现可逆块类
# 一旦多 GPU 确认工作正常,重构并将 PR 发回源代码
class ReversibleBlock(nn.Module):
def __init__(self, f, g):
super().__init__()
self.f = Deterministic(f)
self.g = Deterministic(g)
def forward(self, x, f_args = {}, g_args = {}):
training = self.training
x1, x2 = chunk_tuple(lambda t: torch.chunk(t, 2, dim=2), x)
y1, y2 = None, None
with torch.no_grad():
y1 = sum_tuple(self.f(x2, record_rng = training, **f_args), x1)
y2 = sum_tuple(self.g(y1, record_rng = training, **g_args), x2)
return cat_tuple(y1, y2, cat_dim = 2)
def backward_pass(self, y, dy, f_args = {}, g_args = {}):
y1, y2 = chunk_tuple(lambda t: torch.chunk(t, 2, dim=2), y)
del_tuple(y)
dy1, dy2 = torch.chunk(dy, 2, dim=2)
del dy
with torch.enable_grad():
y1[1].requires_grad = True
gy1 = self.g(y1, set_rng=True, **g_args)
torch.autograd.backward(gy1[1], dy2)
with torch.no_grad():
x2 = subtract_tuple(y2, gy1)
del_tuple(y2)
del gy1
dx1 = dy1 + y1[1].grad
del dy1
y1[1].grad = None
with torch.enable_grad():
x2[1].requires_grad = True
fx2 = self.f(x2, set_rng = True, **f_args)
torch.autograd.backward(fx2[1], dx1)
with torch.no_grad():
x1 = subtract_tuple(y1, fx2)
del fx2
del_tuple(y1)
dx2 = dy2 + x2[1].grad
del dy2
x2[1].grad = None
x2 = map_tuple(lambda t: t.detach(), x2)
x = cat_tuple(x1, x2, cat_dim = -1)
dx = torch.cat((dx1, dx2), dim=2)
return x, dx
class _ReversibleFunction(Function):
# 定义一个静态方法,用于前向传播
@staticmethod
def forward(ctx, x, blocks, kwargs):
# 将传入的参数保存在上下文中
ctx.kwargs = kwargs
# 将传入的参数重新组合
x = (kwargs.pop('coords'), x, kwargs.pop('mask'), kwargs.pop('edges'))
# 遍历每个块并进行前向传播
for block in blocks:
x = block(x, **kwargs)
# 将计算结果保存在上下文中,并将梯度分离
ctx.y = map_tuple(lambda t: t.detach(), x, dim=1)
ctx.blocks = blocks
# 返回计算结果的第二个元素
return x[1]
# 定义一个静态方法,用于反向传播
@staticmethod
def backward(ctx, dy):
# 从上下文中获取保存的数据
y = ctx.y
kwargs = ctx.kwargs
# 反向遍历每个块并进行反向传播
for block in ctx.blocks[::-1]:
y, dy = block.backward_pass(y, dy, **kwargs)
# 返回计算结果的梯度
return dy, None, None
class SequentialSequence(nn.Module):
# 定义一个顺序执行的序列模块
def __init__(self, blocks):
# 初始化函数,接受一个包含多个块的列表作为参数
super().__init__()
# 调用父类的初始化函数
self.blocks = blocks
# 将传入的块列表保存在当前对象的属性中
def forward(self, x):
# 前向传播函数,接受输入参数 x
for (f, g) in self.blocks:
# 遍历块列表中的每个块,每个块包含两个函数 f 和 g
x = sum_tuple(f(x), x, dim = 1)
# 将 f 函数作用在输入 x 上,然后与 x 求和,指定维度为 1
x = sum_tuple(g(x), x, dim = 1)
# 将 g 函数作用在上一步的结果 x 上,然后与 x 求和,指定维度为 1
return x
# 返回最终结果 x
class ReversibleSequence(nn.Module):
# 定义一个可逆执行的序列模块
def __init__(self, blocks):
# 初始化函数,接受一个包含多个块的列表作为参数
super().__init__()
# 调用父类的初始化函数
self.blocks = nn.ModuleList([ReversibleBlock(f, g) for (f, g) in blocks])
# 将传入的块列表中的每个块转换为 ReversibleBlock 对象,并保存在当前对象的属性中
def forward(self, x, **kwargs):
# 前向传播函数,接受输入参数 x 和关键字参数 kwargs
x = map_tuple(lambda t: torch.cat((t, t), dim = -1), x)
# 对输入 x 中的每个元素应用 lambda 函数,将其在最后一个维度上进行拼接
blocks = self.blocks
# 将当前对象的块列表保存在变量 blocks 中
coords, values, mask, edges = x
# 将输入 x 拆分为 coords、values、mask 和 edges 四部分
kwargs = {'coords': coords, 'mask': mask, 'edges': edges, **kwargs}
# 将 coords、mask、edges 和 kwargs 合并为一个字典
x = _ReversibleFunction.apply(values, blocks, kwargs)
# 调用自定义的 _ReversibleFunction 类的 apply 方法,传入 values、blocks 和 kwargs,得到结果 x
x = (coords, x, mask, edges)
# 将 x 重新组合为一个元组
return map_tuple(lambda t: sum(t.chunk(2, dim = -1)) * 0.5, x)
# 对 x 中的每个元素应用 lambda 函数,将其在最后一个维度上进行拆分并求和,然后乘以 0.5
.\lucidrains\lie-transformer-pytorch\lie_transformer_pytorch\se3.py
# 从 math 模块中导入 pi 常数
from math import pi
# 导入 torch 模块
import torch
# 从 functools 模块中导入 wraps 装饰器
from functools import wraps
# 从 torch 模块中导入 acos, atan2, cos, sin 函数
from torch import acos, atan2, cos, sin
# 从 einops 模块中导入 rearrange, repeat 函数
# 常量
THRES = 7e-2
# 辅助函数
# 判断变量是否存在
def exists(val):
return val is not None
# 返回张量的设备和数据类型
def to(t):
return {'device': t.device, 'dtype': t.dtype}
# Taylor 展开函数
def taylor(thres):
def outer(fn):
@wraps(fn)
def inner(x):
usetaylor = x.abs() < THRES
taylor_expanded, full = fn(x, x * x)
return torch.where(usetaylor, taylor_expanded, full)
return inner
return outer
# 用于解析指数映射的辅助函数。在 x=0 附近使用 Taylor 展开
# 参考 http://ethaneade.com/lie_groups.pdf 进行推导
# sinc 函数的 Taylor 展开
@taylor(THRES)
def sinc(x, x2):
""" sin(x)/x """
texpand = 1-x2/6*(1-x2/20*(1-x2/42))
full = sin(x) / x
return texpand, full
# sincc 函数的 Taylor 展开
@taylor(THRES)
def sincc(x, x2):
""" (1-sinc(x))/x^2"""
texpand = 1/6*(1-x2/20*(1-x2/42*(1-x2/72)))
full = (x-sin(x)) / x**3
return texpand, full
# cosc 函数的 Taylor 展开
@taylor(THRES)
def cosc(x, x2):
""" (1-cos(x))/x^2"""
texpand = 1/2*(1-x2/12*(1-x2/30*(1-x2/56)))
full = (1-cos(x)) / x2
return texpand, full
# coscc 函数的 Taylor 展开
@taylor(THRES)
def coscc(x, x2):
texpand = 1/12*(1+x2/60*(1+x2/42*(1+x2/40)))
costerm = (2*(1-cos(x))).clamp(min=1e-6)
full = (1-x*sin(x)/costerm) / x2
return texpand, full
# sinc_inv 函数的 Taylor 展开
@taylor(THRES)
def sinc_inv(x, _):
texpand = 1+(1/6)*x**2 +(7/360)*x**4
full = x / sin(x)
assert not torch.any(torch.isinf(texpand)|torch.isnan(texpand)),'sincinv texpand inf'+torch.any(torch.isinf(texpand))
return texpand, full
# Lie 群作用于 R3
# R3 上的 Hodge 星算子
def cross_matrix(k):
"""Application of hodge star on R3, mapping Λ^1 R3 -> Λ^2 R3"""
K = torch.zeros(*k.shape[:-1], 3, 3, **to(k))
K[...,0,1] = -k[...,2]
K[...,0,2] = k[...,1]
K[...,1,0] = k[...,2]
K[...,1,2] = -k[...,0]
K[...,2,0] = -k[...,1]
K[...,2,1] = k[...,0]
return K
# 逆 Hodge 星算子
def uncross_matrix(K):
"""Application of hodge star on R3, mapping Λ^2 R3 -> Λ^1 R3"""
k = torch.zeros(*K.shape[:-1], **to(K))
k[...,0] = (K[...,2,1] - K[...,1,2])/2
k[...,1] = (K[...,0,2] - K[...,2,0])/2
k[...,2] = (K[...,1,0] - K[...,0,1])/2
return k
# SO3 类
class SO3:
lie_dim = 3
rep_dim = 3
q_dim = 1
def __init__(self, alpha = .2):
super().__init__()
self.alpha = alpha
# 计算指数映射
def exp(self,w):
""" Computes (matrix) exponential Lie algebra elements (in a given basis).
ie out = exp(\sum_i a_i A_i) where A_i are the exponential generators of G.
Input: [a (*,lie_dim)] where * is arbitrarily shaped
Output: [exp(a) (*,rep_dim,rep_dim)] returns the matrix for each."""
""" Rodriguez's formula, assuming shape (*,3)
where components 1,2,3 are the generators for xrot,yrot,zrot"""
theta = w.norm(dim=-1)[..., None, None]
K = cross_matrix(w)
I = torch.eye(3, **to(K))
Rs = I + K * sinc(theta) + (K @ K) * cosc(theta)
return Rs
# 计算对数映射
def log(self,R):
""" Computes components in terms of generators rx,ry,rz. Shape (*,3,3)"""
""" Computes (matrix) logarithm for collection of matrices and converts to Lie algebra basis.
Input [u (*,rep_dim,rep_dim)]
Output [coeffs of log(u) in basis (*,d)] """
trR = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
costheta = ((trR-1) / 2).clamp(max=1, min=-1).unsqueeze(-1)
theta = acos(costheta)
logR = uncross_matrix(R) * sinc_inv(theta)
return logR
# 计算逆元素
def inv(self,g):
""" We can compute the inverse of elements g (*,rep_dim,rep_dim) as exp(-log(g))"""
return self.exp(-self.log(g))
def elems2pairs(self,a):
""" 计算输入中沿着 n 维度的所有 a b 对的 log(e^-b e^a)。
输入: [a (bs,n,d)] 输出: [pairs_ab (bs,n,n,d)] """
# 计算 e^-a 的逆
vinv = self.exp(-a.unsqueeze(-3))
# 计算 e^a
u = self.exp(a.unsqueeze(-2))
# 计算 log(e^-b e^a)
return self.log(vinv@u) # ((bs,1,n,d) -> (bs,1,n,r,r))@((bs,n,1,d) -> (bs,n,1,r,r))
def lift(self, x, nsamples, **kwargs):
""" 假设 p 的形状为 (*,n,2),vals 的形状为 (*,n,c),mask 的形状为 (*,n)
返回形状为 [(*,n*nsamples,lie_dim),(*,n*nsamples,c)] 的 (a,v) """
p, v, m, e = x
# 将 p 展开为 (bs,n*ns,d) 和 (bs,n*ns,qd)
expanded_a = self.lifted_elems(p,nsamples,**kwargs)
nsamples = expanded_a.shape[-2]//m.shape[-1]
# 将 v 和 mask 像 q 一样展开
expanded_v = repeat(v, 'b n c -> b (n m) c', m = nsamples) # (bs,n,c) -> (bs,n,1,c) -> (bs,n,ns,c) -> (bs,n*ns,c)
expanded_mask = repeat(m, 'b n -> b (n m)', m = nsamples) # (bs,n) -> (bs,n,ns) -> (bs,n*ns)
expanded_e = repeat(e, 'b n1 n2 c -> b (n1 m1) (n2 m2) c', m1 = nsamples, m2 = nsamples) if exists(e) else None
# 从 elems 转换为 pairs
paired_a = self.elems2pairs(expanded_a) #(bs,n*ns,d) -> (bs,n*ns,n*ns,d)
embedded_locations = paired_a
return (embedded_locations,expanded_v,expanded_mask, expanded_e)
class SE3(SO3):
# 定义 SE3 类,继承自 SO3 类
lie_dim = 6
# 定义李代数维度为 6
rep_dim = 4
# 定义表示维度为 4
q_dim = 0
# 定义 q 维度为 0
def __init__(self, alpha=.2, per_point=True):
# 初始化函数,接受 alpha 和 per_point 两个参数
super().__init__()
# 调用父类的初始化函数
self.alpha = alpha
# 设置对象的 alpha 属性为传入的 alpha 值
self.per_point = per_point
# 设置对象的 per_point 属性为传入的 per_point 值
def exp(self,w):
# 定义 exp 函数,接受参数 w
dd_kwargs = to(w)
# 将 w 转换为 dd_kwargs
theta = w[...,:3].norm(dim=-1)[...,None,None]
# 计算 w 的前三个元素的范数,并扩展维度
K = cross_matrix(w[...,:3])
# 计算 w 的前三个元素的叉乘矩阵
R = super().exp(w[...,:3])
# 调用父类的 exp 函数,计算 w 的前三个元素的指数映射
I = torch.eye(3, **dd_kwargs)
# 创建 3x3 的单位矩阵
V = I + cosc(theta)*K + sincc(theta)*(K@K)
# 计算 V 矩阵
U = torch.zeros(*w.shape[:-1],4,4, **dd_kwargs)
# 创建全零的 4x4 矩阵
U[...,:3,:3] = R
# 将 R 赋值给 U 的前三行前三列
U[...,:3,3] = (V@w[...,3:].unsqueeze(-1)).squeeze(-1)
# 计算并赋值 U 的前三行第四列
U[...,3,3] = 1
# 设置 U 的第四行第四列为 1
return U
# 返回 U 矩阵
def log(self,U):
# 定义 log 函数,接受参数 U
w = super().log(U[..., :3, :3])
# 调用父类的 log 函数,计算 U 的前三行前三列的对数映射
I = torch.eye(3, **to(w))
# 创建 3x3 的单位矩阵
K = cross_matrix(w[..., :3])
# 计算 w 的前三个元素的叉乘矩阵
theta = w.norm(dim=-1)[..., None, None]#%(2*pi)
# 计算 w 的范数,并扩展维度
cosccc = coscc(theta)
# 计算 coscc(theta)
Vinv = I - K/2 + cosccc*(K@K)
# 计算 Vinv 矩阵
u = (Vinv @ U[..., :3, 3].unsqueeze(-1)).squeeze(-1)
# 计算 u 向量
return torch.cat([w, u], dim=-1)
# 返回拼接后的 w 和 u 向量
def lifted_elems(self,pt,nsamples):
""" pt (bs,n,D) mask (bs,n), per_point specifies whether to
use a different group element per atom in the molecule"""
# 返回 farthest_lift 函数的结果
# same lifts for each point right now
bs,n = pt.shape[:2]
# 获取 pt 的形状
dd_kwargs = to(pt)
# 将 pt 转换为 dd_kwargs
q = torch.randn(bs, (n if self.per_point else 1), nsamples, 4, **dd_kwargs)
# 生成服从标准正态分布的随机数
q /= q.norm(dim=-1).unsqueeze(-1)
# 对 q 进行归一化
theta_2 = atan2(q[..., 1:].norm(dim=-1),q[..., 0])[..., None]
# 计算角度 theta_2
so3_elem = 2 * sinc_inv(theta_2) * q[...,1:]
# 计算 so3_elem
se3_elem = torch.cat([so3_elem, torch.zeros_like(so3_elem)], dim=-1)
# 拼接得到 se3_elem
R = self.exp(se3_elem)
# 计算 se3_elem 的指数映射
T = torch.zeros(bs, n, nsamples, 4, 4, **dd_kwargs)
# 创建全零的 4x4 矩阵
T[..., :, :] = torch.eye(4, **dd_kwargs)
# 将单位矩阵赋值给 T
T[..., :3, 3] = pt[..., None, :]
# 将 pt 赋值给 T 的前三行第四列
a = self.log(T @ R)
# 计算 T @ R 的对数映射
return a.reshape(bs, n * nsamples, 6)
# 返回重塑后的结果
def distance(self,abq_pairs):
# 定义 distance 函数,接受参数 abq_pairs
dist_rot = abq_pairs[...,:3].norm(dim=-1)
# 计算旋转部分的距离
dist_trans = abq_pairs[...,3:].norm(dim=-1)
# 计算平移部分的距离
return dist_rot * self.alpha + (1-self.alpha) * dist_trans
# 返回旋转部分距禂乘以 alpha 加上平移部分距离乘以 (1-alpha) 的结果
.\lucidrains\lie-transformer-pytorch\lie_transformer_pytorch\__init__.py
# 从lie_transformer_pytorch模块中导入LieTransformer类
from lie_transformer_pytorch.lie_transformer_pytorch import LieTransformer
Lie Transformer - Pytorch
Implementation of Lie Transformer, Equivariant Self-Attention, in Pytorch. Only the SE3 version will be present in this repository, as it may be needed for Alphafold2 replication.
Install
$ pip install lie-transformer-pytorch
Usage
import torch
from lie_transformer_pytorch import LieTransformer
model = LieTransformer(
dim = 512,
depth = 2,
heads = 8,
dim_head = 64,
liftsamples = 4
)
coors = torch.randn(1, 64, 3)
features = torch.randn(1, 64, 512)
mask = torch.ones(1, 64).bool()
out = model(features, coors, mask = mask) # (1, 256, 512) <- 256 = (seq len * liftsamples)
Allowing Lie Transformer take care of embedding the features, just specify the number of unique tokens (node types).
import torch
from lie_transformer_pytorch import LieTransformer
model = LieTransformer(
num_tokens = 28, # say 28 different types of atoms
dim = 512,
depth = 2,
heads = 8,
dim_head = 64,
liftsamples = 4
)
atoms = torch.randint(0, 28, (1, 64))
coors = torch.randn(1, 64, 3)
mask = torch.ones(1, 64).bool()
out = model(atoms, coors, mask = mask) # (1, 256, 512) <- 256 = (seq len * liftsamples)
Although it was not in the paper, I decided to allow for passing in edge information as well (bond types). The edge information will be embedded by the dimension specified, concatted with the location, and passed through the MLP before summed with the attention matrix.
Simply set two more keyword arguments on initialization of the transformer, and then pass in the specific bond types as shape b x seq x seq
.
import torch
from lie_transformer_pytorch import LieTransformer
model = LieTransformer(
num_tokens = 28, # say 28 different types of atoms
num_edge_types = 4, # number of different edge types
edge_dim = 16, # dimension of edges
dim = 512,
depth = 2,
heads = 8,
dim_head = 64,
liftsamples = 4
)
atoms = torch.randint(0, 28, (1, 64))
bonds = torch.randint(0, 4, (1, 64, 64))
coors = torch.randn(1, 64, 3)
mask = torch.ones(1, 64).bool()
out = model(atoms, coors, edges = bonds, mask = mask) # (1, 256, 512) <- 256 = (seq len * liftsamples)
Credit
This repository is largely adapted from LieConv, cited below
Citations
@misc{hutchinson2020lietransformer,
title = {LieTransformer: Equivariant self-attention for Lie Groups},
author = {Michael Hutchinson and Charline Le Lan and Sheheryar Zaidi and Emilien Dupont and Yee Whye Teh and Hyunjik Kim},
year = {2020},
eprint = {2012.10885},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@misc{finzi2020generalizing,
title = {Generalizing Convolutional Neural Networks for Equivariance to Lie Groups on Arbitrary Continuous Data},
author = {Marc Finzi and Samuel Stanton and Pavel Izmailov and Andrew Gordon Wilson},
year = {2020},
eprint = {2002.12880},
archivePrefix = {arXiv},
primaryClass = {stat.ML}
}
.\lucidrains\lie-transformer-pytorch\setup.py
# 导入设置和查找包的函数
from setuptools import setup, find_packages
# 设置包的元数据
setup(
name = 'lie-transformer-pytorch', # 包的名称
packages = find_packages(), # 查找所有包
version = '0.0.17', # 版本号
license='MIT', # 许可证
description = 'Lie Transformer - Pytorch', # 描述
author = 'Phil Wang', # 作者
author_email = 'lucidrains@gmail.com', # 作者邮箱
url = 'https://github.com/lucidrains/lie-transformer-pytorch', # 项目链接
keywords = [ # 关键词列表
'artificial intelligence',
'attention mechanism',
'transformers',
'equivariance',
'lifting',
'lie groups'
],
install_requires=[ # 安装依赖
'torch>=1.6',
'einops>=0.3'
],
setup_requires=[ # 设置需要的依赖
'pytest-runner',
],
tests_require=[ # 测试需要的依赖
'pytest'
],
classifiers=[ # 分类器
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\lie-transformer-pytorch\tests.py
# 导入 torch 库
import torch
# 从 lie_transformer_pytorch 库中导入 LieTransformer 类
from lie_transformer_pytorch import LieTransformer
# 定义测试 LieTransformer 类的函数
def test_transformer():
# 创建 LieTransformer 模型对象,设置维度为 512,深度为 1
model = LieTransformer(
dim = 512,
depth = 1
)
# 生成一个形状为 (1, 64, 512) 的随机张量 feats
feats = torch.randn(1, 64, 512)
# 生成一个形状为 (1, 64, 3) 的随机张量 coors
coors = torch.randn(1, 64, 3)
# 生成一个形状为 (1, 64) 的全为 True 的布尔张量 mask
mask = torch.ones(1, 64).bool()
# 使用 LieTransformer 模型处理 feats, coors 和 mask,得到输出 out
out = model(feats, coors, mask = mask)
# 断言输出 out 的形状为 (1, 256, 512),如果不是则输出 'transformer runs'
assert out.shape == (1, 256, 512), 'transformer runs'
.\lucidrains\lightweight-gan\lightweight_gan\cli.py
# 导入所需的库
import os
import fire
import random
from retry.api import retry_call
from tqdm import tqdm
from datetime import datetime
from functools import wraps
from lightweight_gan import Trainer, NanException
from lightweight_gan.diff_augment_test import DiffAugmentTest
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
import numpy as np
# 检查值是否存在
def exists(val):
return val is not None
# 如果值存在则返回该值,否则返回默认值
def default(val, d):
return val if exists(val) else d
# 将元素转换为列表
def cast_list(el):
return el if isinstance(el, list) else [el]
# 生成带时间戳的文件名
def timestamped_filename(prefix = 'generated-'):
now = datetime.now()
timestamp = now.strftime("%m-%d-%Y_%H-%M-%S")
return f'{prefix}{timestamp}'
# 设置随机种子
def set_seed(seed):
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(seed)
random.seed(seed)
# 运行训练过程
def run_training(rank, world_size, model_args, data, load_from, new, num_train_steps, name, seed, use_aim, aim_repo, aim_run_hash):
is_main = rank == 0
is_ddp = world_size > 1
if is_ddp:
set_seed(seed)
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group('nccl', rank=rank, world_size=world_size)
print(f"{rank + 1}/{world_size} process initialized.")
model_args.update(
is_ddp = is_ddp,
rank = rank,
world_size = world_size
)
model = Trainer(**model_args, hparams=model_args, use_aim=use_aim, aim_repo=aim_repo, aim_run_hash=aim_run_hash)
if not new:
model.load(load_from)
else:
model.clear()
model.set_data_src(data)
progress_bar = tqdm(initial = model.steps, total = num_train_steps, mininterval=10., desc=f'{name}<{data}>')
while model.steps < num_train_steps:
retry_call(model.train, tries=3, exceptions=NanException)
progress_bar.n = model.steps
progress_bar.refresh()
if is_main and model.steps % 50 == 0:
model.print_log()
model.save(model.checkpoint_num)
if is_ddp:
dist.destroy_process_group()
# 从文件夹中训练模型
def train_from_folder(
data = './data',
results_dir = './results',
models_dir = './models',
name = 'default',
new = False,
load_from = -1,
image_size = 256,
optimizer = 'adam',
fmap_max = 512,
transparent = False,
greyscale = False,
batch_size = 10,
gradient_accumulate_every = 4,
num_train_steps = 150000,
learning_rate = 2e-4,
save_every = 1000,
evaluate_every = 1000,
generate = False,
generate_types = ['default', 'ema'],
generate_interpolation = False,
aug_test = False,
aug_prob=None,
aug_types=['cutout', 'translation'],
dataset_aug_prob=0.,
attn_res_layers = [32],
freq_chan_attn = False,
disc_output_size = 1,
dual_contrast_loss = False,
antialias = False,
interpolation_num_steps = 100,
save_frames = False,
num_image_tiles = None,
num_workers = None,
multi_gpus = False,
calculate_fid_every = None,
calculate_fid_num_images = 12800,
clear_fid_cache = False,
seed = 42,
amp = False,
show_progress = False,
use_aim = False,
aim_repo = None,
aim_run_hash = None,
load_strict = True
):
num_image_tiles = default(num_image_tiles, 4 if image_size > 512 else 8)
# 定义模型参数字典
model_args = dict(
name = name,
results_dir = results_dir,
models_dir = models_dir,
batch_size = batch_size,
gradient_accumulate_every = gradient_accumulate_every,
attn_res_layers = cast_list(attn_res_layers),
freq_chan_attn = freq_chan_attn,
disc_output_size = disc_output_size,
dual_contrast_loss = dual_contrast_loss,
antialias = antialias,
image_size = image_size,
num_image_tiles = num_image_tiles,
optimizer = optimizer,
num_workers = num_workers,
fmap_max = fmap_max,
transparent = transparent,
greyscale = greyscale,
lr = learning_rate,
save_every = save_every,
evaluate_every = evaluate_every,
aug_prob = aug_prob,
aug_types = cast_list(aug_types),
dataset_aug_prob = dataset_aug_prob,
calculate_fid_every = calculate_fid_every,
calculate_fid_num_images = calculate_fid_num_images,
clear_fid_cache = clear_fid_cache,
amp = amp,
load_strict = load_strict
)
# 如果需要生成图片
if generate:
# 创建训练器对象,传入模型参数和是否使用 AIM
model = Trainer(**model_args, use_aim = use_aim)
# 加载模型
model.load(load_from)
# 生成样本名称
samples_name = timestamped_filename()
# 获取当前训练步数
checkpoint = model.checkpoint_num
# 生成图片
dir_result = model.generate(samples_name, num_image_tiles, checkpoint, generate_types)
# 打印生成的样本图片路径
print(f'sample images generated at {dir_result}')
return
# 如果需要生成插值图片
if generate_interpolation:
# 创建训练器对象,传入模型参数和是否使用 AIM
model = Trainer(**model_args, use_aim = use_aim)
# 加载模型
model.load(load_from)
# 生成样本名称
samples_name = timestamped_filename()
# 生成插值图片
model.generate_interpolation(samples_name, num_image_tiles, num_steps = interpolation_num_steps, save_frames = save_frames)
# 打印生成的插值图片路径
print(f'interpolation generated at {results_dir}/{name}/{samples_name}')
return
# 如果需要展示训练进度
if show_progress:
# 创建训练器对象,传入模型参数和是否使用 AIM
model = Trainer(**model_args, use_aim = use_aim)
# 展示训练进度
model.show_progress(num_images=num_image_tiles, types=generate_types)
return
# 如果需要进行数据增强测试
if aug_test:
# 进行数据增强测试
DiffAugmentTest(data=data, image_size=image_size, batch_size=batch_size, types=aug_types, nrow=num_image_tiles)
return
# 获取当前可用的 GPU 数量
world_size = torch.cuda.device_count()
# 如果只有一个 GPU 或者不使用多 GPU 训练
if world_size == 1 or not multi_gpus:
# 单 GPU 训练
run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed, use_aim, aim_repo, aim_run_hash)
return
# 使用多 GPU 训练
mp.spawn(run_training,
args=(world_size, model_args, data, load_from, new, num_train_steps, name, seed, use_aim, aim_repo, aim_run_hash,),
nprocs=world_size,
join=True)
# 定义主函数
def main():
# 使用 Fire 库将 train_from_folder 函数转换为命令行接口
fire.Fire(train_from_folder)