.\lucidrains\CoCa-pytorch\coca_pytorch\__init__.py
# import the CoCa class from the coca_pytorch module
from coca_pytorch.coca_pytorch import CoCa
CoCa - Pytorch
Implementation of CoCa, Contrastive Captioners are Image-Text Foundation Models, in Pytorch. They were able to elegantly fold contrastive learning into a conventional encoder / decoder (image to text) transformer, achieving SOTA 91.0% top-1 accuracy on ImageNet with a finetuned encoder.
This repository also adopts the specific transformer architecture from PaLM for both the unimodal and multimodal transformers, as well as for the cross attention blocks (parallel SwiGLU feedforwards).
Update: CoCa has been trained by the good folks over at OpenCLIP
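To make the parallel-block idea concrete, here is a minimal sketch (an illustration with assumed hyperparameters, not this repository's code) of a PaLM-style block, in which attention and a SwiGLU feedforward are computed from a single normalization and summed into the residual:
import torch
from torch import nn

class ParallelBlock(nn.Module):
    def __init__(self, dim, heads = 8, ff_mult = 4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first = True)
        ff_inner = dim * ff_mult
        self.ff_in = nn.Linear(dim, ff_inner * 2)  # projects to SwiGLU gate and value
        self.ff_out = nn.Linear(ff_inner, dim)

    def forward(self, x):
        normed = self.norm(x)
        attn_out, _ = self.attn(normed, normed, normed)
        gate, value = self.ff_in(normed).chunk(2, dim = -1)
        ff_out = self.ff_out(nn.functional.silu(gate) * value)
        return x + attn_out + ff_out  # attention and feedforward applied in parallel, then added to the residual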
Install
$ pip install coca-pytorch
Usage
First install vit-pytorch for the image encoder, which needs to be pretrained
$ pip install "vit-pytorch>=0.40.2"
Then
import torch
# import vision transformer
from vit_pytorch.simple_vit_with_patch_dropout import SimpleViT
from vit_pytorch.extractor import Extractor
vit = SimpleViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
patch_dropout = 0.5 # https://arxiv.org/abs/2212.00794
)
vit = Extractor(vit, return_embeddings_only = True, detach = False)
# extractor will enable it so the vision transformer returns its embeddings
# import CoCa and instantiate it
from coca_pytorch.coca_pytorch import CoCa
coca = CoCa(
dim = 512, # model dimension
img_encoder = vit, # vision transformer - image encoder, returning image embeddings as (batch, seq, dim)
image_dim = 1024, # image embedding dimension, if not the same as model dimensions
num_tokens = 20000, # number of text tokens
unimodal_depth = 6, # depth of the unimodal transformer
multimodal_depth = 6, # depth of the multimodal transformer
dim_head = 64, # dimension per attention head
heads = 8, # number of attention heads
caption_loss_weight = 1., # weight on the autoregressive caption loss
contrastive_loss_weight = 1., # weight on the contrastive loss between image and text CLS embeddings
).cuda()
# mock text and images
text = torch.randint(0, 20000, (4, 512)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# train by giving CoCa your text and images with `return_loss = True`
loss = coca(
text = text,
images = images,
return_loss = True # set this to True to get the full caption + contrastive loss
)
loss.backward()
# do the above for as much text and images...
# then you can get the caption logits as so
logits = coca(
text = text,
images = images
) # (4, 512, 20000)
# and the CLIP-like text and image embeddings as
text_embeds, image_embeds = coca(
text = text,
images = images,
return_embeddings = True
) # (4, 512), (4, 512)
Citations
@inproceedings{Yu2022CoCaCC,
title = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
author = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
year = {2022}
}
@inproceedings{Chowdhery2022PaLMSL,
title = {PaLM: Scaling Language Modeling with Pathways},
author = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
year = {2022}
}
.\lucidrains\CoCa-pytorch\setup.py
# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
    name = 'CoCa-pytorch',                            # package name
    packages = find_packages(exclude=[]),             # find all packages
    version = '0.1.0',                                # version number
    license='MIT',                                    # license
    description = 'CoCa, Contrastive Captioners are Image-Text Foundation Models - Pytorch', # description
    author = 'Phil Wang',                             # author
    author_email = 'lucidrains@gmail.com',            # author email
    long_description_content_type = 'text/markdown',  # long description content type
    url = 'https://github.com/lucidrains/CoCa-pytorch', # project url
    keywords = [                                      # keywords
        'artificial intelligence',
        'deep learning',
        'transformers',
        'attention mechanism',
        'contrastive learning',
        'multimodal'
    ],
    install_requires=[                                # dependencies
        'einops>=0.4',
        'torch>=1.6',
    ],
    classifiers=[                                     # classifiers
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
.\lucidrains\coco-lm-pytorch\coco_lm_pytorch\coco_lm_pytorch.py
import math
from functools import reduce

import torch
from torch import nn, einsum
import torch.nn.functional as F

# helper functions

# log of the input tensor, with a small eps added to avoid taking the log of zero
def log(t, eps=1e-9):
    return torch.log(t + eps)

# l2 normalize the input tensor along the last dimension
def norm(t):
    return F.normalize(t, p = 2, dim = -1)

# generate gumbel noise
def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

# sample from the input logits using gumbel noise (gumbel-max trick)
def gumbel_sample(t, temperature = 1.):
    return ((t / temperature) + gumbel_noise(t)).argmax(dim=-1)

# generate a boolean mask where each position is True with probability `prob`
def prob_mask_like(t, prob):
    return torch.zeros_like(t).float().uniform_(0, 1) < prob

# generate a mask that is True wherever t equals any of the given token ids
def mask_with_tokens(t, token_ids):
    init_no_mask = torch.full_like(t, False, dtype=torch.bool)
    mask = reduce(lambda acc, el: acc | (t == el), token_ids, init_no_mask)
    return mask

# sample a random subset of the given mask, covering roughly `prob` of each row
def get_mask_subset_with_prob(mask, prob):
    batch, seq_len, device = *mask.shape, mask.device
    max_masked = math.ceil(prob * seq_len)

    num_tokens = mask.sum(dim=-1, keepdim=True)
    mask_excess = (mask.cumsum(dim=-1) > (num_tokens * prob).ceil())
    mask_excess = mask_excess[:, :max_masked]

    rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
    _, sampled_indices = rand.topk(max_masked, dim=-1)
    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)

    new_mask = torch.zeros((batch, seq_len + 1), device=device)
    new_mask.scatter_(-1, sampled_indices, 1)
    return new_mask[:, 1:].bool()
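# usage sketch (illustrative): for MLM-style masking, pass a mask of maskable
# positions and get back a random subset covering roughly `prob` of each row
#
#   maskable = torch.ones(2, 10, dtype = torch.bool)
#   subset = get_mask_subset_with_prob(maskable, 0.15)  # ~15% of positions True per row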
# hidden layer extractor class, for magically adding an adapter to a language model for pretraining
class HiddenLayerExtractor(nn.Module):
def __init__(self, net, layer = -2):
super().__init__()
self.net = net
self.layer = layer
self.hidden = None
self.hook_registered = False
def _find_layer(self):
if type(self.layer) == str:
modules = dict([*self.net.named_modules()])
return modules.get(self.layer, None)
elif type(self.layer) == int:
children = [*self.net.children()]
return children[self.layer]
return None
def _hook(self, _, __, output):
self.hidden = output
def _register_hook(self):
layer = self._find_layer()
assert layer is not None, f'hidden layer ({self.layer}) not found'
handle = layer.register_forward_hook(self._hook)
self.hook_registered = True
def forward(self, x):
if self.layer == -1:
return self.net(x)
if not self.hook_registered:
self._register_hook()
_ = self.net(x)
hidden = self.hidden
self.hidden = None
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
return hidden
# main COCO class (an ELECTRA-style pretraining wrapper)
class COCO(nn.Module):
def __init__(
self,
generator,
discriminator,
*,
discr_dim,
num_tokens = None,
discr_layer = -1,
mask_prob = 0.15,
replace_prob = 0.85,
random_token_prob = 0.,
pad_token_id = 0,
cls_token_id = 1,
mask_token_id = 2,
mask_ignore_token_ids = [],
disc_weight = 50.,
gen_weight = 1.,
cl_weight = 1.,
temperature = 1.,
crop_percentage = 0.5
):
        super().__init__()

        # generator and discriminator
        self.generator = generator
        self.discriminator = discriminator

        # extract hidden features from the given layer of the discriminator
        self.discriminator = HiddenLayerExtractor(discriminator, layer = discr_layer)

        # project the discriminator dimension down to 1, for the correction logits
        self.to_correction_logits = nn.Linear(discr_dim, 1)

        # mlm related probabilities
        self.mask_prob = mask_prob
        self.replace_prob = replace_prob

        # number of tokens
        self.num_tokens = num_tokens
        self.random_token_prob = random_token_prob

        # token ids
        self.cls_token_id = cls_token_id
        self.pad_token_id = pad_token_id
        self.mask_token_id = mask_token_id
        self.mask_ignore_token_ids = set([*mask_ignore_token_ids, pad_token_id, cls_token_id])

        # sampling temperature
        self.temperature = temperature

        # loss weights
        self.disc_weight = disc_weight
        self.gen_weight = gen_weight
        self.cl_weight = cl_weight

        # contrastive loss temperature parameter
        self.cl_temperature = nn.Parameter(torch.tensor(1.))

        # crop percentage for the contrastive view
        self.crop_percentage = crop_percentage
.\lucidrains\coco-lm-pytorch\coco_lm_pytorch\__init__.py
# import the COCO class from the coco_lm_pytorch module
from coco_lm_pytorch.coco_lm_pytorch import COCO
COCO LM Pretraining (wip)
Implementation of COCO-LM, Correcting and Contrasting Text Sequences for Language Model Pretraining, in Pytorch. They were able to make contrastive learning work in a self-supervised manner for language model pretraining. Seems like a solid successor to ELECTRA.
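As a rough sketch (an illustration, not this repository's internals), the sequence-level contrastive objective can be pictured as an InfoNCE loss between the [CLS] embedding of each sequence and that of its cropped view, with the rest of the batch serving as negatives:
import torch
import torch.nn.functional as F

def sequence_contrastive_loss(cls_embeds, cropped_cls_embeds, temperature = 1.):
    a, b = map(lambda t: F.normalize(t, dim = -1), (cls_embeds, cropped_cls_embeds))
    sim = a @ b.t() / temperature                          # (batch, batch) cosine similarities
    labels = torch.arange(a.shape[0], device = a.device)   # positives lie on the diagonal
    return F.cross_entropy(sim, labels)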
Install
$ pip install coco-lm-pytorch
Usage
An example using the x-transformers library
$ pip install x-transformers
Then
import torch
from coco_lm_pytorch import COCO
# (1) instantiate the generator and discriminator, making sure that the generator is roughly a quarter to a half of the size of the discriminator
from x_transformers import TransformerWrapper, Encoder
generator = TransformerWrapper(
num_tokens = 20000,
emb_dim = 128,
max_seq_len = 1024,
attn_layers = Encoder(
dim = 256, # smaller hidden dimension
heads = 4, # fewer heads
ff_mult = 2, # smaller feedforward dimension
depth = 1
)
)
discriminator = TransformerWrapper(
num_tokens = 20000,
emb_dim = 128,
max_seq_len = 1024,
attn_layers = Encoder(
dim = 1024,
heads = 16,
ff_mult = 4,
depth = 12
)
)
# (2) weight tie the token and positional embeddings of generator and discriminator
generator.token_emb = discriminator.token_emb
generator.pos_emb = discriminator.pos_emb
# weight tie any other embeddings if available, token type embeddings, etc.
# (3) instantiate COCO
trainer = COCO(
generator,
discriminator,
discr_dim = 1024, # the embedding dimension of the discriminator
discr_layer = 'norm', # the layer name in the discriminator whose output is used to predict whether each token is original or replaced
cls_token_id = 1, # a token id must be reserved for [CLS], which is prepended to the sequence for contrastive learning
mask_token_id = 2, # the token id reserved for masking
pad_token_id = 0, # the token id for padding
mask_prob = 0.15, # masking probability for masked language modeling
mask_ignore_token_ids = [], # ids of tokens to ignore for mask modeling ex. (cls, sep)
cl_weight = 1., # weight for the contrastive learning loss
disc_weight = 1., # weight for the corrective learning loss
gen_weight = 1. # weight for the MLM loss
)
# (4) train
data = torch.randint(0, 20000, (1, 1024))
loss = trainer(data)
loss.backward()
# after much training, the discriminator should have improved
torch.save(discriminator, './pretrained-model.pt')
Citations
@misc{meng2021cocolm,
title = {COCO-LM: Correcting and Contrasting Text Sequences for Language Model Pretraining},
author = {Yu Meng and Chenyan Xiong and Payal Bajaj and Saurabh Tiwary and Paul Bennett and Jiawei Han and Xia Song},
year = {2021},
eprint = {2102.08473},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
.\lucidrains\coco-lm-pytorch\setup.py
# import setup and find_packages from setuptools
from setuptools import setup, find_packages

# package metadata
setup(
    name = 'coco-lm-pytorch',          # package name
    packages = find_packages(),        # find all packages
    version = '0.0.2',                 # version number
    license='MIT',                     # license
    description = 'COCO - Pytorch',    # description
    author = 'Phil Wang',              # author
    author_email = 'lucidrains@gmail.com',  # author email
    url = 'https://github.com/lucidrains/coco-lm-pytorch', # project url
    keywords = [                       # keywords
        'transformers',
        'artificial intelligence',
        'deep learning',
        'pretraining'
    ],
    install_requires=[                 # dependencies
        'torch>=1.6.0',
        'einops',
        'x-transformers'
    ],
    classifiers=[                      # classifiers
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.7',
],
)
coffee-neural-network
a simple neural network in coffeescript
running
$ npm install
$ npm install coffee-script -g
$ coffee nn.coffee
.\lucidrains\CoLT5-attention\colt5_attention\attend.py
from collections import namedtuple
from functools import wraps
from packaging import version

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange

# config namedtuple, storing which scaled dot product attention backends to enable
Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
# check whether a value exists
def exists(val):
    return val is not None

# decorator that ensures the wrapped function only runs once
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

# print helper that only prints once
print_once = once(print)

# main Attend class
class Attend(nn.Module):
def __init__(
self,
dropout = 0.,
causal = False,
use_flash = False
):
super().__init__()
self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)
self.causal = causal
self.register_buffer("mask", None, persistent=False)
self.use_flash = use_flash
assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
        # determine efficient attention configs for cuda and cpu
self.cpu_config = Config(True, True, True)
self.cuda_config = None
if not torch.cuda.is_available() or not use_flash:
return
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
if device_properties.major == 8 and device_properties.minor == 0:
print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
self.cuda_config = Config(True, False, False)
else:
print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
self.cuda_config = Config(False, True, True)
    # causal mask helper, cached on the module
def get_mask(self, n, device):
if exists(self.mask) and self.mask.shape[-1] >= n:
return self.mask[:n, :n]
mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
self.register_buffer("mask", mask, persistent=False)
return mask
    # flash attention, using pytorch's scaled_dot_product_attention
def flash_attn(self, q, k, v, mask = None):
_, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
        # check whether the mask exists and expand it to a compatible shape

        if exists(mask) and mask.ndim != 4:
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        # select the sdp kernel config depending on whether the tensors are on cuda

        config = self.cuda_config if is_cuda else self.cpu_config

        # use pytorch 2.0's flash attention

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = mask,
dropout_p = self.dropout if self.training else 0.,
is_causal = self.causal
)
return out
    # forward pass, taking queries (q), keys (k), values (v) and an optional mask

    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        # sequence length of the queries and their device

        n, device = q.shape[-2], q.device

        # scale by the inverse square root of the head dimension

        scale = q.shape[-1] ** -0.5

        # if flash attention is enabled, defer to flash_attn

        if self.use_flash:
            return self.flash_attn(q, k, v, mask = mask)

        # similarity
sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
        # key padding mask
if exists(mask):
if mask.ndim != 4:
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
        # causal mask
if self.causal:
causal_mask = self.get_mask(n, device)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
        # attention
attn = sim.softmax(dim=-1)
attn = self.attn_dropout(attn)
        # aggregate values
out = einsum("b h i j, b h j d -> b h i d", attn, v)
return out
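# usage sketch (illustrative, with assumed shapes): Attend consumes multi-head
# tensors of shape (batch, heads, seq_len, dim_head)
#
#   attend = Attend(dropout = 0., causal = True)
#   q = k = v = torch.randn(2, 8, 256, 64)
#   out = attend(q, k, v)  # (2, 8, 256, 64)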
.\lucidrains\CoLT5-attention\colt5_attention\coor_descent.py
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast

from einops import rearrange

# check whether a value exists
def exists(val):
    return val is not None

# return val if it exists, otherwise the default d
def default(val, d):
    return val if exists(val) else d

# log of the input tensor, clamped to avoid taking the log of zero
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

# coordinate descent is run in full precision, so autocast is disabled

@autocast(enabled = False)
def coor_descent(
s,
*,
n_iters,
k,
eps = 1e-1,
eps_init = None,
eps_decay = 1.,
mask = None
):
"""
coordinate descent - https://arxiv.org/abs/1502.04759, utilized in https://arxiv.org/abs/2303.09752
ε-scaling - https://arxiv.org/abs/1610.06519, utilized in https://arxiv.org/abs/2304.04947
in a follow up paper applying coordinate descent routing to efficient fine tuning
they were able to cut n_iters from 50 -> 20 by setting eps_init = 4 and eps_decay = 0.7
eps was dependent on the task, and ranged from 0.02 to 1
"""
    assert n_iters > 0

    # the mask value is the most negative value for the dtype of s

    mask_value = -torch.finfo(s.dtype).max

    if not isinstance(k, torch.Tensor):
        k = torch.Tensor([k]).to(s)
    else:
        k = rearrange(k, '... -> ... 1')

    logk = log(k)

    # fill masked out positions of s with the mask value

    if exists(mask):
        s = s.masked_fill(~mask, mask_value)

    # initialize a and b

    a = 0
    b = -s

    # initialize the current epsilon, for epsilon-scaling

    current_eps = max(default(eps_init, eps), eps)

    for _ in range(n_iters):
        sb = ((s + b) / current_eps)

        if exists(mask):
            sb = sb.masked_fill(~mask, mask_value)

        # update a and b

        a = current_eps * (logk - sb.logsumexp(dim = -1, keepdim = True))
        b = -F.relu(s + a)

        # decay the current epsilon, down to the floor of eps

        current_eps = max(current_eps * eps_decay, eps)

    # final scores

    scores = ((s + a + b) / current_eps).exp()

    if exists(mask):
        scores = scores.masked_fill(~mask, 0.)

    return scores
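# usage sketch (illustrative): the returned scores form a soft k-hot mask over
# the last dimension, each row summing to approximately k
#
#   s = torch.randn(2, 10)
#   scores = coor_descent(s, n_iters = 20, k = 2, eps = 1e-1)
#   scores.sum(dim = -1)  # ~2. per row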
.\lucidrains\CoLT5-attention\colt5_attention\topk.py
import torch
from torch.cuda.amp import autocast
from collections import namedtuple
from colt5_attention.coor_descent import coor_descent
TopkReturn = namedtuple('TopkReturn', ['values', 'indices', 'coor_descent_values', 'gates'])
@autocast(enabled = False)
def topk(
x,
k,
coor_descent_k_ratio = 9 / 8,
n_iters = 20,
eps = 1e-1,
eps_init = None,
eps_decay = 1.,
mask = None,
fused = False,
non_differentiable = False
):
"""
differentiable top-k on last dimension
"""
    if non_differentiable:
        # just use vanilla, non-differentiable topk
        values, indices = torch.topk(x, k = k, dim = -1)
        return TopkReturn(values, indices, None, None)
assert coor_descent_k_ratio >= 1.
assert k > 0
    # whether to use the fused triton kernel or not

    fn = coor_descent

    if fused and x.is_cuda:
        from colt5_attention.triton_coor_descent import triton_coor_descent
        fn = triton_coor_descent

    # do coordinate descent for the gradients

    coor_descent_out = fn(
        x,
        k = min(k * coor_descent_k_ratio, x.shape[-1]), # fetch a bit more than k for better learning, as in the CoLT5 paper (they fetched 9/8 times more)
        mask = mask,
        n_iters = n_iters,
        eps = eps,
        eps_init = eps_init,
        eps_decay = eps_decay
    )
    # do straight through: the gates are 1 in the forward pass, but carry the coordinate descent gradients

    gates = coor_descent_out + (1 - coor_descent_out).detach()
    x = x * gates

    # hard topk

    values, indices = torch.topk(x, k, dim = -1)

    # return something that looks like a usual topk, but is now differentiable

    coor_descent_values = coor_descent_out.gather(-1, indices)
    gates = gates.gather(-1, indices)

    return TopkReturn(values, indices, coor_descent_values, gates)
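# usage sketch (illustrative): values / indices match a regular torch.topk, but
# gradients flow back to all of x through the straight-through gates
#
#   x = torch.randn(2, 10, requires_grad = True)
#   ret = topk(x, k = 3)
#   ret.values.sum().backward()  # x.grad is populated via the coordinate descent scores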
.\lucidrains\CoLT5-attention\colt5_attention\transformer_block.py
import math
from functools import partial
from collections import namedtuple
import torch
import torch.nn.functional as F
from torch import Tensor, nn, einsum
from typing import Tuple, Optional
from local_attention import LocalMHA
from einops import rearrange, repeat, pack, unpack
from colt5_attention.attend import Attend
# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def divisible_by(numer, denom):
    return (numer % denom) == 0

# pack a single tensor according to the given pattern
def pack_one(t, pattern):
    return pack([t], pattern)

# unpack a single tensor according to the given pattern
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# pad a tensor along the given dimension to a multiple of `multiple`
def pad_to_multiple(tensor, multiple, dim=-1, value=0):
    seq_len = tensor.shape[dim]
    m = seq_len / multiple
    if m.is_integer():
        return tensor, seq_len

    remainder = math.ceil(m) * multiple - seq_len
    pad_offset = (0,) * (-1 - dim) * 2
    padded_tensor = F.pad(tensor, (*pad_offset, 0, remainder), value=value)
    return padded_tensor, seq_len

# gather from x along the batch dimension using the given indices
def batched_gather(x, indices):
    batch_range = create_batch_range(indices, indices.ndim - 1)
    return x[batch_range, indices]

# identity function
def identity(t):
    return t

# l2 normalization
def l2norm(t):
    return F.normalize(t, dim=-1)

# tensor helpers

# create a broadcastable batch index tensor
def create_batch_range(t, right_pad_dims=1):
    b, device = t.shape[0], t.device
    batch_range = torch.arange(b, device=device)
    pad_dims = ((1,) * right_pad_dims)
    return batch_range.reshape(-1, *pad_dims)
# rotary positional embedding
# https://arxiv.org/abs/2104.09864
class RotaryEmbedding(nn.Module):
def __init__(self, dim, theta=10000):
super().__init__()
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
@property
def device(self):
return next(self.buffers()).device
def forward(self, seq_len):
t = torch.arange(seq_len, device=self.device).type_as(self.inv_freq)
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
freqs = torch.cat((freqs, freqs), dim=-1)
return freqs
# rotate half the feature dimensions, for rotary embedding
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
# apply rotary positional embedding
def apply_rotary_pos_emb(pos, t):
return t * pos.cos() + rotate_half(t) * pos.sin()
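# usage sketch (illustrative): positions are generated once per sequence length
# and applied to the queries and keys before attention
#
#   rotary = RotaryEmbedding(dim = 64)
#   pos = rotary(seq_len)                                    # (seq_len, 64)
#   q, k = map(lambda t: apply_rotary_pos_emb(pos, t), (q, k))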
# normalization
# rms normalization
class RMSNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x):
normed = F.normalize(x, dim=-1)
return normed * self.scale * self.gamma
# modules
# feedforward
def FeedForward(dim, mult=4):
dim_hidden = int(dim * mult)
return nn.Sequential(
RMSNorm(dim),
nn.Linear(dim, dim_hidden),
nn.GELU(),
nn.Linear(dim_hidden, dim)
)
# self attention
class SelfAttention(nn.Module):
def __init__(
self,
dim,
dim_head=64,
heads=8,
use_flash=False,
prenorm=False
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
dim_hidden = dim_head * heads
self.norm = RMSNorm(dim) if prenorm else nn.Identity()
self.attend = Attend(use_flash=use_flash)
self.to_qkv = nn.Linear(dim, dim_hidden * 3, bias=False)
self.to_out = nn.Linear(dim_hidden, dim, bias=False)
    def forward(self, x):
        h = self.heads
        x = self.norm(x)

        # queries, keys, values

        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        # attention

        out = self.attend(q, k, v)

        # merge heads

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head=64,
        heads=8,
        multiply_keys_by_score=False,
        use_flash=False
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        dim_hidden = dim_head * heads

        # whether to multiply the keys by the router's normalized scores
        self.multiply_keys_by_score = multiply_keys_by_score

        self.norm = RMSNorm(dim)

        # learned null key / value, so there is always something to attend to
        self.null_kv = nn.Parameter(torch.randn(2, heads, dim_head))

        self.attend = Attend(use_flash = use_flash)

        self.to_q = nn.Linear(dim, dim_hidden, bias = False)
        self.to_kv = nn.Linear(dim, dim_hidden * 2, bias = False)
        self.to_out = nn.Linear(dim_hidden, dim, bias = False)
    def forward(
        self,
        x,
        context = None,
        mask = None,
        normalized_scores_kv = None,
        normalized_scores_q = None,
        rotary_emb: Optional[Tuple[Tensor, Tensor]] = None
    ):
        """
        einops:
        b - batch
        h - heads, or number of heads per route
        r - routing dimension, for routing different sets of key / values - should be more expressive
        n - sequence dimension
        d - head dimension
        i - input model dimension
        """

        batch, h = x.shape[0], self.heads

        x = self.norm(x)

        if exists(context):
            context = self.norm(context)

        context = default(context, x)

        # add a routing dimension to the context if it does not have one

        if context.ndim == 3:
            context = rearrange(context, 'b n d -> b 1 n d')

        # bring the normalized key / value scores to a broadcastable shape

        if exists(normalized_scores_kv) and isinstance(normalized_scores_kv, torch.Tensor):
            if normalized_scores_kv.ndim == 2:
                normalized_scores_kv = rearrange(normalized_scores_kv, 'b n -> b 1 n')

            normalized_scores_kv = rearrange(normalized_scores_kv, 'b r n -> b r 1 n 1')

        # number of key / value routes in the context

        num_kv_routes = context.shape[1]

        # queries

        q = self.to_q(x)
        q = rearrange(q, 'b n (h d) -> b h n d', h = h)

        # multiply the queries by the normalized query scores, if given

        if exists(normalized_scores_q) and isinstance(normalized_scores_q, torch.Tensor):
            q = q * rearrange(normalized_scores_q, 'b n -> b 1 n 1')

        # handle key / values, with the routing dimension, dividing the number of heads in between the routes

        assert divisible_by(h, num_kv_routes), 'number of heads must be divisible by the number of key / value routes'
        heads_per_route = h // num_kv_routes

        kv_weight = rearrange(self.to_kv.weight, '(r h d) i -> r h d i', h = heads_per_route, r = num_kv_routes)

        kv = einsum('r h d i, b r n i -> b r h n d', kv_weight, context)
        k, v = kv.chunk(2, dim = -1)

        # multiply the values (and optionally the keys) by the normalized key / value scores

        if exists(normalized_scores_kv):
            v = v * normalized_scores_kv

            if self.multiply_keys_by_score:
                k = k * normalized_scores_kv

        # rotary embeddings, if given

        if exists(rotary_emb):
            q_rotary_emb, k_rotary_emb = rotary_emb
            q = apply_rotary_pos_emb(q_rotary_emb, q)

            if k_rotary_emb.ndim == 4:
                k_rotary_emb = repeat(k_rotary_emb, 'b 1 n d -> b r 1 n d', r = k.shape[1])

            k = apply_rotary_pos_emb(k_rotary_emb, k)

        # merge the routing dimension of the keys / values into the heads

        k, v = map(lambda t: rearrange(t, 'b r h n d -> b (r h) n d'), (k, v))

        # null key / values

        nk, nv = map(lambda t: repeat(t, 'h d -> b h 1 d', b = batch), self.null_kv)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # masking, padded on the left for the null key / value

        if exists(mask):
            if mask.ndim == 3:
                mask = repeat(mask, 'b r j -> b (r h) 1 j', h = heads_per_route)
            else:
                mask = rearrange(mask, 'b j -> b 1 1 j')

            mask = F.pad(mask, (1, 0), value = True)

        # attention

        out = self.attend(q, k, v, mask = mask)

        # merge heads

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
from colt5_attention.coor_descent import coor_descent

# namedtuple for the router's return values

RouterReturn = namedtuple('RouterReturn', ['indices', 'scores', 'routed_tokens', 'routed_mask'])

# router module, implementing coordinate descent routing
class CoordinateDescentRouter(nn.Module):
"""
from Wright et al. https://arxiv.org/abs/1502.04759
then adopted by https://arxiv.org/abs/2211.01267 for multi-vector document retrieval by Qian et al
finally, used successfully by this paper for routing to heavy branch attention / feedforward
"""
    def __init__(
        self,
        dim,
        straight_through = True,
        n_iters = 20,                      # use 20 iterations, with epsilon-scaling
        fetch_k_ratio = 9 / 8,             # in the paper, they fetch a slightly higher k (times this ratio) for better learning
        eps = 0.03,                        # coordinate descent epsilon. in a recent paper, they used 0.03 for text and 1.0 for speech
        eps_decay = 0.7,
        eps_init = 4.,
        num_routing_tokens = 1,
        learned_routing_tokens = False,
        use_triton = False,
        cosine_sim_routing = False,
        cosine_sim_scale = 8,
        route_block_size = None,
        triton_checkpoint_segments = None  # whether to recompute the coordinate descent in segments; with 4 segments and 50 iterations, the backward pass is ~3x faster, at the cost of the forward pass and some memory for saving the initial a and b
    ):
        super().__init__()
        assert fetch_k_ratio >= 1.

        self.n_iters = n_iters
        self.fetch_k_ratio = fetch_k_ratio
        self.coor_descent = coor_descent

        # epsilon-scaling related hyperparameters

        self.eps = eps
        self.eps_decay = eps_decay
        self.eps_init = eps_init

        if use_triton:
            from colt5_attention.triton_coor_descent import triton_coor_descent
            triton_checkpoint_segments = default(triton_checkpoint_segments, n_iters // 5)
            self.coor_descent = partial(triton_coor_descent, checkpoint_segments = triton_checkpoint_segments)

        self.is_one_routing_token = num_routing_tokens == 1
        self.num_routing_tokens = num_routing_tokens
        self.route_block_size = route_block_size
        self.routing_token = nn.Parameter(torch.randn(num_routing_tokens, dim)) if not learned_routing_tokens else None
        self.straight_through = straight_through

        # whether to route with cosine similarity

        self.cosine_sim_routing = cosine_sim_routing
        self.cosine_sim_scale = cosine_sim_scale

    # scatter the routed tokens back into the source tensor

    def route_back(self, src, routed_tokens, indices):
        batch_range = create_batch_range(routed_tokens)
        src[batch_range, indices] = routed_tokens
        return src
    def forward(
        self,
        x,
        *,
        num_tokens,
        mask = None,
        random_route = False,
        routing_tokens = None,
        keep_one_route_dim = False  # if only one route, whether to keep the route dimension
    ):
        ...

# main classes

# conditionally routed feedforward
class ConditionalRoutedFeedForward(nn.Module):
    def __init__(
        self,
        dim,
        *,
        num_heavy_tokens,
        light_ff_mult = 0.5,
        heavy_ff_mult = 4,
        router_straight_through = True, # would make sure all normalized scores are 1., still differentiable
        router_kwargs: dict = {},
        use_triton = False
    ):
        super().__init__()
        self.num_heavy_tokens = num_heavy_tokens

        if use_triton:
            router_kwargs = {**router_kwargs, 'use_triton': True}

        # router

        self.router = CoordinateDescentRouter(
            dim = dim,
            straight_through = router_straight_through,
            **router_kwargs
        )

        # light and heavy feedforward branches

        self.light_ff = FeedForward(dim, light_ff_mult)
        self.heavy_ff = FeedForward(dim, heavy_ff_mult)
    def forward(
        self,
        x,
        mask = None,
        num_heavy_tokens = None
    ):
        device, num_heavy_tokens = x.device, default(num_heavy_tokens, self.num_heavy_tokens)

        # light feedforward sees all tokens (hidden dimension is only 1/2 of the model dimension)

        light_out = self.light_ff(x)

        # route tokens appropriately for the heavy branch

        indices, normalized_scores, routed_tokens, _ = self.router(x, num_tokens=num_heavy_tokens, mask=mask)

        # do the heavier branch with only the routed tokens

        routed_tokens_out = self.heavy_ff(routed_tokens) * rearrange(normalized_scores, '... -> ... 1')

        # scatter the output of the heavy feedforward branch back

        if exists(indices):
            heavy_out = torch.zeros_like(x)
            heavy_out = self.router.route_back(heavy_out, routed_tokens_out, indices)
        else:
            heavy_out = routed_tokens_out

        # sum the light and heavy branches

        return light_out + heavy_out
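# usage sketch (illustrative, with assumed hyperparameters): every token passes
# through the light feedforward, while only the routed tokens hit the heavy one
#
#   ff = ConditionalRoutedFeedForward(dim = 512, num_heavy_tokens = 64)
#   out = ff(torch.randn(1, 1024, 512))  # (1, 1024, 512)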
# conditionally routed attention

class ConditionalRoutedAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        num_heavy_tokens_q,
        num_heavy_tokens_kv,
        num_routed_kv = 1,
        light_dim_head = 64,
        light_heads = 8,
        light_window_size = 128,        # each token would see ~ 64 tokens to either the left or the right
        heavy_dim_head = 64,
        heavy_heads = 8,
        router_straight_through = True, # would make sure all normalized scores are 1., still differentiable
        router_kwargs: dict = {},
        multiply_keys_by_score = False,
        multiply_queries_by_score = False,
        use_triton = False,
        use_null_q_tokens = True,
        use_flash_attn = False,
        rotary_emb = False
    ):
        super().__init__()

        if use_triton:
            router_kwargs = {**router_kwargs, 'use_triton': True}

        self.num_heavy_tokens_q = num_heavy_tokens_q
        self.num_heavy_tokens_kv = num_heavy_tokens_kv

        self.multiply_queries_by_score = multiply_queries_by_score

        self.light_attn = LocalMHA(
            dim = dim,
            dim_head = light_dim_head,
            heads = light_heads,
            window_size = light_window_size // 2,
            prenorm = True,
            causal = False,
            use_rotary_pos_emb = False,
            look_backward = 1,
            look_forward = 1
        )

        self.null_q_token = None
        if use_null_q_tokens:
            self.null_q_token = nn.Parameter(torch.randn(dim)) # for the query tokens not selected by the router, give them a learned output embedding

        self.q_router = CoordinateDescentRouter(
            dim = dim,
            straight_through = router_straight_through,
            **router_kwargs
        )

        self.kv_router = CoordinateDescentRouter(
            dim = dim,
            num_routing_tokens = num_routed_kv,
            straight_through = router_straight_through,
            **router_kwargs
        )

        self.heavy_attn = Attention(
            dim = dim,
            dim_head = heavy_dim_head,
            heads = heavy_heads,
            multiply_keys_by_score = multiply_keys_by_score,
            use_flash = use_flash_attn
        )

        # rotary embedding

        self.rotary_emb = RotaryEmbedding(heavy_dim_head) if rotary_emb else None
    def forward(
        self,
        x,
        *,
        num_heavy_tokens_q = None,
        num_heavy_tokens_kv = None,
        mask = None
    ):
        batch, seq, device = *x.shape[:2], x.device

        num_heavy_tokens_q = default(num_heavy_tokens_q, self.num_heavy_tokens_q)
        num_heavy_tokens_kv = default(num_heavy_tokens_kv, self.num_heavy_tokens_kv)

        # light local attention sees all tokens in a limited context

        light_out = self.light_attn(x, mask = mask)

        # route tokens appropriately for the heavy branch

        indices_q, normalized_scores_q, routed_tokens_q, _ = self.q_router(x, num_tokens = num_heavy_tokens_q, mask = mask)
        indices_kv, normalized_scores_kv, routed_tokens_kv, routed_tokens_kv_mask = self.kv_router(x, num_tokens = num_heavy_tokens_kv, mask = mask)

        # get the rotary embeddings, if specified

        rotary_emb = None

        if exists(self.rotary_emb):
            seq_rotary_emb = self.rotary_emb(seq)
            q_rotary_emb = rearrange(seq_rotary_emb[indices_q], 'b n d -> b 1 n d') if exists(indices_q) else seq_rotary_emb
            k_rotary_emb = rearrange(seq_rotary_emb[indices_kv], '... n d -> ... 1 n d') if exists(indices_kv) else seq_rotary_emb
            rotary_emb = (q_rotary_emb, k_rotary_emb)

        # do the heavier branch with only the routed tokens

        routed_tokens_out = self.heavy_attn(
            routed_tokens_q,
            mask = routed_tokens_kv_mask,
            context = routed_tokens_kv,
            rotary_emb = rotary_emb,
            normalized_scores_kv = normalized_scores_kv,
            normalized_scores_q = normalized_scores_q if self.multiply_queries_by_score else None
        )

        routed_tokens_out = routed_tokens_out * rearrange(normalized_scores_q, '... -> ... 1')

        # scatter back the output of the heavy branch

        if exists(indices_q):
            if exists(self.null_q_token):
                heavy_out = rearrange(self.null_q_token, 'd -> 1 1 d')
                heavy_out = heavy_out.expand_as(x).clone()
            else:
                heavy_out = torch.zeros_like(x)

            heavy_out = self.q_router.route_back(heavy_out, routed_tokens_out, indices_q)
        else:
            heavy_out = routed_tokens_out

        # sum the outputs of the light and heavy branches

        return light_out + heavy_out
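# usage sketch (illustrative, with assumed hyperparameters):
#
#   attn = ConditionalRoutedAttention(dim = 512, num_heavy_tokens_q = 64, num_heavy_tokens_kv = 64)
#   out = attn(torch.randn(1, 1024, 512))  # (1, 1024, 512)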
# conditionally routed attention for image feature maps

class ConditionalRoutedImageAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        num_heavy_tokens_q,
        num_heavy_tokens_kv,
        num_routed_kv = 1,
        light_dim_head = 64,
        light_heads = 8,
        light_window_size = 128,        # each token would see ~ 64 tokens to either the left or the right
        heavy_dim_head = 64,
        heavy_heads = 8,
        router_straight_through = True, # would make sure all normalized scores are 1., still differentiable
        router_kwargs: dict = {},
        multiply_keys_by_score = False,
        multiply_queries_by_score = False,
        use_triton = False,
        use_null_q_tokens = True,
        use_flash_attn = False,
        channel_first = False
    ):
        super().__init__()
        self.channel_first = channel_first

        if use_triton:
            router_kwargs = {**router_kwargs, 'use_triton': True}

        self.num_heavy_tokens_q = num_heavy_tokens_q
        self.num_heavy_tokens_kv = num_heavy_tokens_kv

        self.multiply_queries_by_score = multiply_queries_by_score

        self.light_window_size = light_window_size

        # light self attention over local windows

        self.light_attn = SelfAttention(
            dim = dim,
            dim_head = light_dim_head,
            heads = light_heads,
            prenorm = True
        )

        self.null_q_token = None
        if use_null_q_tokens:
            self.null_q_token = nn.Parameter(torch.randn(dim)) # for the query tokens not selected by the router, give them a learned output embedding

        self.q_router = CoordinateDescentRouter(
            dim = dim,
            straight_through = router_straight_through,
            **router_kwargs
        )

        self.kv_router = CoordinateDescentRouter(
            dim = dim,
            num_routing_tokens = num_routed_kv,
            straight_through = router_straight_through,
            **router_kwargs
        )

        self.heavy_attn = Attention(
            dim = dim,
            dim_head = heavy_dim_head,
            heads = heavy_heads,
            multiply_keys_by_score = multiply_keys_by_score,
            use_flash = use_flash_attn
        )
    def forward(
        self,
        x,
        *,
        num_heavy_tokens_q = None,
        num_heavy_tokens_kv = None,
        mask = None
    ):
        assert x.ndim == 4
        batch, device, channel_first, w = x.shape[0], x.device, self.channel_first, self.light_window_size

        if channel_first:
            x = rearrange(x, 'b d ... -> b ... d')

        num_heavy_tokens_q = default(num_heavy_tokens_q, self.num_heavy_tokens_q)
        num_heavy_tokens_kv = default(num_heavy_tokens_kv, self.num_heavy_tokens_kv)

        # light local attention sees all tokens in a limited context

        light_input = rearrange(x, 'b (h p1) (w p2) d -> b h w (p1 p2) d', p1 = w, p2 = w)
        x, ps = pack_one(light_input, '* n d')

        light_out = self.light_attn(x)
        light_out = unpack_one(light_out, ps, '* n d')
        light_out = rearrange(light_out, 'b h w (p1 p2) d -> b (h p1) (w p2) d', p1 = w, p2 = w)

        # route tokens appropriately for the heavy branch

        indices_q, normalized_scores_q, routed_tokens_q, _ = self.q_router(x, num_tokens = num_heavy_tokens_q, mask = mask)
        indices_kv, normalized_scores_kv, routed_tokens_kv, routed_tokens_kv_mask = self.kv_router(x, num_tokens = num_heavy_tokens_kv, mask = mask)

        # do the heavier branch with only the routed tokens

        routed_tokens_out = self.heavy_attn(
            routed_tokens_q,
            mask = routed_tokens_kv_mask,
            context = routed_tokens_kv,
            normalized_scores_kv = normalized_scores_kv,
            normalized_scores_q = normalized_scores_q if self.multiply_queries_by_score else None
        )

        routed_tokens_out = routed_tokens_out * rearrange(normalized_scores_q, '... -> ... 1')

        # scatter back the output of the heavy branch

        if exists(self.null_q_token):
            heavy_out = rearrange(self.null_q_token, 'd -> 1 1 d')
            heavy_out = heavy_out.expand_as(x).clone()
        else:
            heavy_out = torch.zeros_like(x)

        heavy_out = self.q_router.route_back(heavy_out, routed_tokens_out, indices_q)

        heavy_out = unpack_one(heavy_out, ps, '* n d')
        heavy_out = rearrange(heavy_out, 'b h w (p1 p2) d -> b (h p1) (w p2) d', p1 = w, p2 = w)

        # sum the light and heavy branches

        out = light_out + heavy_out

        if channel_first:
            out = rearrange(out, 'b ... d -> b d ...')

        return out
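# usage sketch (illustrative, with assumed hyperparameters): operates on feature
# maps, (batch, height, width, dim) by default, or channel first if specified
#
#   attn = ConditionalRoutedImageAttention(dim = 256, num_heavy_tokens_q = 64, num_heavy_tokens_kv = 64, light_window_size = 8)
#   out = attn(torch.randn(1, 32, 32, 256))  # (1, 32, 32, 256)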
# adapting the conditionally routed attention for autoregressive attention

class ConditionalRoutedAutoregressiveAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        num_heavy_tokens_q,
        num_heavy_tokens_kv,
        num_routed_kv = 1,
        light_dim_head = 64,
        light_heads = 8,
        light_window_size = 128,        # each token would see ~ 64 tokens to either the left or the right
        heavy_window_size = None,
        heavy_dim_head = 64,
        heavy_heads = 8,
        router_straight_through = True, # would make sure all normalized scores are 1., still differentiable
        router_kwargs: dict = {},
        multiply_keys_by_score = False,
        multiply_queries_by_score = False,
        use_triton = False,
        use_null_q_tokens = True,
        use_flash_attn = False,
        rotary_emb = False
    ):
        super().__init__()

        if use_triton:
            router_kwargs = {**router_kwargs, 'use_triton': True}

        self.num_heavy_tokens_q = num_heavy_tokens_q
        self.num_heavy_tokens_kv = num_heavy_tokens_kv

        self.multiply_queries_by_score = multiply_queries_by_score

        self.heavy_window_size = default(heavy_window_size, light_window_size)

        self.light_attn = LocalMHA(
            dim = dim,
            dim_head = light_dim_head,
            heads = light_heads,
            window_size = light_window_size,
            prenorm = True,
            causal = True,
            exact_windowsize = False,
            use_rotary_pos_emb = False
        )

        self.null_q_token = None
        if use_null_q_tokens:
            self.null_q_token = nn.Parameter(torch.randn(dim)) # for the query tokens not selected by the router, give them a learned output embedding

        self.q_router = CoordinateDescentRouter(
            dim = dim,
            straight_through = router_straight_through,
            **router_kwargs
        )

        self.kv_router = CoordinateDescentRouter(
            dim = dim,
            num_routing_tokens = num_routed_kv,
            straight_through = router_straight_through,
            **router_kwargs
        )

        self.heavy_attn = Attention(
            dim = dim,
            dim_head = heavy_dim_head,
            heads = heavy_heads,
            multiply_keys_by_score = multiply_keys_by_score,
            use_flash = use_flash_attn
        )

        # rotary embedding

        self.rotary_emb = RotaryEmbedding(heavy_dim_head) if rotary_emb else None
    def forward(
        self,
        x,
        *,
        num_heavy_tokens_q = None,
        num_heavy_tokens_kv = None,
        random_route = False
    ):
        ...

# adapting the conditionally routed self attention to cross attention
class ConditionalRoutedCrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        num_tokens_q,
        num_tokens_kv,
        num_sets_kv = 1,                # setting this greater than 1 would route multiple sets of key / values, each of size num_tokens_kv, using this many routing tokens
        dim_head = 64,
        heads = 8,
        router_straight_through = True, # would make sure all normalized scores are 1., still differentiable
        router_kwargs: dict = {},
        kv_routing_tokens = 1,
        multiply_keys_by_score = False,
        use_triton = False,
        use_null_q_tokens = True,
        use_flash_attn = False,
        route_block_size = None
    ):
        super().__init__()

        if use_triton:
            router_kwargs = {**router_kwargs, 'use_triton': True}

        self.num_tokens_q = num_tokens_q
        self.num_tokens_kv = num_tokens_kv

        self.null_q_token = None
        if use_null_q_tokens:
            self.null_q_token = nn.Parameter(torch.randn(dim)) # for the query tokens not selected by the router, give them a learned output embedding

        self.q_router = CoordinateDescentRouter(
            dim = dim,
            straight_through = router_straight_through,
            **router_kwargs
        )

        self.kv_router = CoordinateDescentRouter(
            dim = dim,
            straight_through = router_straight_through,
            num_routing_tokens = kv_routing_tokens,
            route_block_size = route_block_size,
            **router_kwargs
        )

        self.heavy_attn = Attention(
            dim = dim,
            dim_head = dim_head,
            heads = heads,
            multiply_keys_by_score = multiply_keys_by_score,
            use_flash = use_flash_attn
        )
    def forward(
        self,
        x,
        context,
        *,
        num_tokens_q = None,
        num_tokens_kv = None,
        mask = None,
        context_mask = None
    ):
        batch, device = x.shape[0], x.device

        # route the queries

        query_length = x.shape[-2]
        num_tokens_q = default(num_tokens_q, self.num_tokens_q)

        indices_q = None
        normalized_scores_q = None
        routed_tokens_q = x

        should_route_queries = query_length > num_tokens_q

        if should_route_queries:
            indices_q, normalized_scores_q, routed_tokens_q, _ = self.q_router(x, num_tokens = num_tokens_q, mask = mask)

        # route the long contexts

        key_value_length = context.shape[-2]
        num_tokens_kv = default(num_tokens_kv, self.num_tokens_kv)

        routed_tokens_kv = context
        routed_tokens_kv_mask = context_mask
        normalized_scores_kv = None

        should_route_kv = key_value_length > num_tokens_kv

        if should_route_kv:
            indices_kv, normalized_scores_kv, routed_tokens_kv, routed_tokens_kv_mask = self.kv_router(context, num_tokens = num_tokens_kv, mask = context_mask)

        # do the heavier branch with only routed tokens

        routed_tokens_out = self.heavy_attn(
            routed_tokens_q,
            mask = routed_tokens_kv_mask,
            context = routed_tokens_kv,
            normalized_scores_kv = normalized_scores_kv
        )

        if should_route_queries:
            routed_tokens_out = routed_tokens_out * rearrange(normalized_scores_q, '... -> ... 1')

        # early return if queries did not undergo routing

        if not should_route_queries:
            return routed_tokens_out

        # otherwise, scatter back the query outputs

        if exists(self.null_q_token):
            out = rearrange(self.null_q_token, 'd -> 1 1 d')
            out = out.expand_as(x).clone()
        else:
            out = torch.zeros_like(x)

        if exists(indices_q):
            out = self.q_router.route_back(out, routed_tokens_out, indices_q)

        return out
# conditionally routed transformer block

class ConditionalRoutedTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        *,
        num_heavy_attn_tokens_q,
        num_heavy_attn_tokens_kv,
        num_routed_kv = 1,
        num_heavy_ff_tokens,
        light_dim_head = 64,
        light_heads = 8,
        light_window_size = 128,
        heavy_dim_head = 64,
        heavy_heads = 8,
        light_ff_mult = 0.5,
        heavy_ff_mult = 4,
        router_straight_through = True,
        router_kwargs: dict = {},
        multiply_keys_by_score = False,
        multiply_queries_by_score = False,
        use_triton = False,
        use_null_q_tokens = True,
        use_flash_attn = False
    ):
        super().__init__()

        self.conditional_ff = ConditionalRoutedFeedForward(
            dim,
            num_heavy_tokens = num_heavy_ff_tokens,
            light_ff_mult = light_ff_mult,
            heavy_ff_mult = heavy_ff_mult,
            router_straight_through = router_straight_through,
            router_kwargs = router_kwargs,
            use_triton = use_triton
        )

        self.conditional_attn = ConditionalRoutedAttention(
            dim,
            light_dim_head = light_dim_head,
            light_heads = light_heads,
            light_window_size = light_window_size,
            heavy_dim_head = heavy_dim_head,
            heavy_heads = heavy_heads,
            num_heavy_tokens_q = num_heavy_attn_tokens_q,
            num_heavy_tokens_kv = num_heavy_attn_tokens_kv,
            num_routed_kv = num_routed_kv,
            router_straight_through = router_straight_through,
            router_kwargs = router_kwargs,
            multiply_keys_by_score = multiply_keys_by_score,
            multiply_queries_by_score = multiply_queries_by_score,
            use_triton = use_triton,
            use_null_q_tokens = use_null_q_tokens,
            use_flash_attn = use_flash_attn
        )

    def forward(
        self,
        x,
        mask = None,
        num_heavy_attn_tokens_q = None,
        num_heavy_attn_tokens_kv = None,
        num_heavy_ff_tokens = None
    ):
        # attention, with residual

        x = self.conditional_attn(x, mask = mask, num_heavy_tokens_q = num_heavy_attn_tokens_q, num_heavy_tokens_kv = num_heavy_attn_tokens_kv) + x

        # feedforward, with residual

        x = self.conditional_ff(x, mask = mask, num_heavy_tokens = num_heavy_ff_tokens) + x

        return x
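# usage sketch (illustrative, with assumed hyperparameters): a full block,
# attention then feedforward, each with light and heavy routed branches
#
#   block = ConditionalRoutedTransformerBlock(dim = 512, num_heavy_attn_tokens_q = 64, num_heavy_attn_tokens_kv = 64, num_heavy_ff_tokens = 64)
#   out = block(torch.randn(1, 1024, 512))  # (1, 1024, 512)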
.\lucidrains\CoLT5-attention\colt5_attention\triton_coor_descent.py
from math import log

import torch
from torch import Tensor
from torch import autograd
import torch.nn.functional as F
from torch.cuda.amp import autocast, custom_fwd, custom_bwd

from colt5_attention.coor_descent import coor_descent
from einops import pack, unpack, repeat

# try to import triton; exit with a hint if it is not installed

try:
    import triton
    import triton.language as tl
except ImportError as e:
    print('triton is not installed, please install by running `pip install triton -U --pre`')
    exit()

# make sure we are on the latest triton

from packaging import version
assert version.parse(triton.__version__) >= version.parse('2.0')
# helper functions

# check whether a value exists
def exists(val):
    return val is not None

# return val if it exists, otherwise the default d
def default(val, d):
    return val if exists(val) else d

# number of warps to use for a given block size
def calc_num_warps(block_size):
    num_warps = 4
    if block_size >= 2048:
        num_warps = 8
    if block_size >= 4096:
        num_warps = 16
    return num_warps

# pack a single tensor according to the given pattern
def pack_one(t, pattern):
    return pack([t], pattern)

# unpack a single tensor according to the given pattern
def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# split num into `groups` roughly equal integers that sum back to num
def num_to_groups(num, groups):
    assert 0 < groups <= num
    floor = num // groups
    remainder = num % groups
    out = []
    for ind in range(groups):
        out.append(floor + int(ind < remainder))
    assert sum(out) == num
    return out
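# quick illustration: num_to_groups(10, 3) -> [4, 3, 3]; it is used below to
# split the descent iterations into checkpointed segments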
# forwards

# triton kernel for the coordinate descent forward pass

@triton.jit
def coor_descent_kernel_forward(
a_ptr,
b_ptr,
input_ptr,
mask_ptr,
k_ptr,
a_iter_stride,
b_row_stride,
b_iter_stride,
input_row_stride,
mask_row_stride,
n_iters,
current_eps,
eps_decay,
eps,
n_cols,
BLOCK_SIZE: tl.constexpr
):
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    col_mask = col_offsets < n_cols

    # load mask as ints (since bool breaks triton)

    mask_start_ptr = mask_ptr + row_idx * mask_row_stride
    mask_ptrs = mask_start_ptr + col_offsets
    mask_ints = tl.load(mask_ptrs, mask = col_mask, other = 0)
    mask = mask_ints == 1

    # load a and b

    a_ptr = a_ptr + row_idx
    a = tl.load(a_ptr)

    b_start_ptr = b_ptr + row_idx * b_row_stride
    b_ptrs = b_start_ptr + col_offsets
    b = tl.load(b_ptrs, mask = col_mask, other = 0)

    # load the scores s

    row_start_ptr = input_ptr + row_idx * input_row_stride
    input_ptrs = row_start_ptr + col_offsets
    s = tl.load(input_ptrs, mask = mask, other = -float('inf'))

    # load k - controls the sparsity of the output

    k_ptr = k_ptr + row_idx
    k = tl.load(k_ptr)

    # initialize some constants

    logk = tl.log(k)

    for _ in range(n_iters):
        a = (s + b) / current_eps
        a = tl.where(mask, a, -float('inf'))

        # stable log sum exp

        a_max = tl.max(a, axis = 0)
        a_minus_max = tl.where(mask, a - a_max, -float('inf'))
        exp = tl.exp(a_minus_max)
        sum_exp = tl.sum(exp, axis = 0)
        log_sum_exp = tl.log(sum_exp) + a_max

        a = current_eps * (logk - log_sum_exp)

        # update b

        b = s + a
        b = tl.where(b >= 0., -b, 0.)

        # decay epsilon, from epsilon-scaling

        current_eps *= eps_decay

        if current_eps < eps:
            current_eps = eps

    # store a and b for the next round

    next_a_ptrs = a_ptr + a_iter_stride
    next_b_ptrs = b_ptrs + b_iter_stride

    tl.store(next_a_ptrs, a)
    tl.store(next_b_ptrs, b, mask = col_mask)
# backwards

# triton kernel for the coordinate descent backward pass

@triton.jit
def coor_descent_kernel_backward(
dk_ptr,
input_ptr,
a_ptr,
b_ptr,
mask_ptr,
ds_ptr,
db_ptr,
k_ptr,
last_da_ptr,
input_row_stride,
b_row_stride,
mask_row_stride,
ds_row_stride,
db_row_stride,
n_iters,
eps_init,
eps_decay,
eps,
n_cols,
BLOCK_SIZE: tl.constexpr
):
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)

    # load and generate mask

    col_mask = col_offsets < n_cols

    # load mask as ints (since bool breaks triton)

    mask_start_ptr = mask_ptr + row_idx * mask_row_stride
    mask_ptrs = mask_start_ptr + col_offsets
    mask_ints = tl.load(mask_ptrs, mask = col_mask, other = 0)
    mask = mask_ints == 1

    # load the initial a and b

    a_ptr = a_ptr + row_idx
    init_a = tl.load(a_ptr)

    b_start_ptr = b_ptr + row_idx * b_row_stride
    b_ptrs = b_start_ptr + col_offsets
    init_b = tl.load(b_ptrs, mask = mask, other = 0)

    # load input

    row_start_ptr = input_ptr + row_idx * input_row_stride
    input_ptrs = row_start_ptr + col_offsets
    s = tl.load(input_ptrs, mask = mask, other = -float('inf'))

    # load k - controls the sparsity of the output

    k_ptr = k_ptr + row_idx
    k = tl.load(k_ptr)
    logk = tl.log(k)

    # load the last da

    last_da_ptr = last_da_ptr + row_idx
    last_da = tl.load(last_da_ptr)

    # load initial ds

    ds_row_start_ptr = ds_ptr + row_idx * ds_row_stride
    ds_ptrs = ds_row_start_ptr + col_offsets
    ds = tl.load(ds_ptrs, mask = mask, other = 0.)

    # load initial db

    db_row_start_ptr = db_ptr + row_idx * db_row_stride
    db_ptrs = db_row_start_ptr + col_offsets
    db = tl.load(db_ptrs, mask = mask, other = 0.)

    # load initial dk

    dk_ptr = dk_ptr + row_idx
    dk = tl.load(dk_ptr)

    # backwards

    for ind in range(n_iters):
        a = init_a
        b = init_b

        sa = s * 0
        softmax = s * 0

        # calculate epsilon

        current_eps = eps_init / eps_decay

        # recompute the forward pass up to this iteration

        for _ in range(n_iters - ind):
            # update epsilon

            current_eps *= eps_decay

            if current_eps < eps:
                current_eps = eps

            # update a

            sb = (s + b) / current_eps
            sb = tl.where(mask, sb, -float('inf'))

            # stable log sum exp

            sb_max = tl.max(sb, axis = 0)
            sb_minus_max = tl.where(mask, sb - sb_max, -float('inf'))
            exp = tl.exp(sb_minus_max)
            sum_exp = tl.sum(exp, axis = 0)
            softmax = exp / sum_exp
            log_sum_exp = tl.log(sum_exp) + sb_max

            a = current_eps * (logk - log_sum_exp)

            # update b

            sa = s + a
            b = tl.where(sa > 0., -sa, 0.)

        # go backwards

        dsa = db * tl.where(sa > 0, -1., 0.)
        ds += dsa

        da = tl.sum(dsa, axis = 0) + last_da

        dk += da * current_eps

        dsb = da * -softmax
        ds += dsb

        db = dsb

        last_da *= 0.

    # store dk, ds and db

    tl.store(dk_ptr, dk)
    tl.store(ds_ptrs, ds, mask = col_mask)
    tl.store(db_ptrs, db, mask = col_mask)
# custom autograd function wrapping the triton kernels

class _coor_descent(autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        x,
        n_iters,
        k,
        eps,
        eps_init,
        eps_decay,
        mask,
        checkpoint_segments
    ):
        assert n_iters > 0
        assert x.is_cuda, 'triton coordinate descent must be on cuda'

        batch, requires_grad, device, dtype = x.shape[0], x.requires_grad, x.device, x.dtype

        # if no mask is given, all positions participate

        if not exists(mask):
            mask = torch.ones_like(x, dtype=torch.bool, device=x.device)

        x, shape = pack_one(x, '* n')
        mask, _ = pack_one(mask, '* n')

        # fill masked out positions of x with the most negative value for the dtype

        x = x.masked_fill(~mask, -torch.finfo(x.dtype).max)
        mask_ints = mask.int()

        epsilons = []
        eps_init = default(eps_init, eps)
        current_eps = float(max(eps_init, eps))

        n_rows, n_cols = x.shape

        # if k is an int or float, expand it to one value per row

        if isinstance(k, (int, float)):
            k = torch.full((n_rows,), k)

        assert k.numel() == n_rows

        k = k.to(x)

        BLOCK_SIZE = triton.next_power_of_2(n_cols)

        assert BLOCK_SIZE <= 131072, 'the maximum block size allowed is 131072 for triton cuda kernel - set the `route_block_size` for the CoordinateDescentRouter to be this value or less in order to uniformly route to get around this limitation'

        num_warps = calc_num_warps(BLOCK_SIZE)

        # buffers for the checkpointed a and b across segments

        checkpointed_a = torch.empty((checkpoint_segments + 1, n_rows), device=device, dtype=dtype)
        checkpointed_b = torch.empty((checkpoint_segments + 1, n_rows, n_cols), device=device, dtype=dtype)

        checkpointed_a[0] = torch.zeros_like(k)
        checkpointed_b[0] = -x

        for ind, segment_iters in enumerate(num_to_groups(n_iters, checkpoint_segments)):
            is_last = ind == (checkpoint_segments - 1)

            epsilons.append(current_eps)

            # run the coordinate descent forward kernel for this segment

            coor_descent_kernel_forward[(n_rows,)](
                checkpointed_a[ind],
                checkpointed_b[ind],
                x,
                mask_ints,
                k,
                checkpointed_a.stride(0),
                n_cols,
                checkpointed_b.stride(0),
                x.stride(0),
                mask_ints.stride(0),
                segment_iters,
                current_eps,
                eps_decay,
                eps,
                n_cols,
                num_warps=num_warps,
                BLOCK_SIZE=BLOCK_SIZE,
            )

            current_eps *= (eps_decay ** segment_iters)
            current_eps = max(current_eps, eps)

        last_a, last_b = map(lambda t: t[-1], (checkpointed_a, checkpointed_b))
        y = torch.exp((last_a[..., None] + last_b + x) / current_eps)

        epsilons.append(current_eps)

        if requires_grad:
            checkpointed_a = checkpointed_a[:-1]
            checkpointed_b = checkpointed_b[:-1]

            ctx.args = (n_iters, checkpoint_segments, epsilons, eps_decay, eps)
            ctx.save_for_backward(x, y, k, mask, checkpointed_a, checkpointed_b)

        y = unpack_one(y, shape, '* n')
        return y
    @staticmethod
    @custom_bwd
    def backward(
        ctx,
        grad_probs
    ):
        assert grad_probs.is_cuda

        batch = grad_probs.shape[0]

        # retrieve the saved arguments and tensors from the forward pass

        n_iters, checkpoint_segments, epsilons, eps_decay, eps = ctx.args
        x, y, k, mask, checkpointed_a, checkpointed_b = ctx.saved_tensors

        grad_probs, shape = pack_one(grad_probs, '* n')

        # zero out the gradients of masked out positions

        if exists(mask):
            grad_probs = grad_probs.masked_fill(~mask, 0.)

        n_rows, n_cols = grad_probs.shape

        BLOCK_SIZE = triton.next_power_of_2(n_cols)
        num_warps = calc_num_warps(BLOCK_SIZE)

        *epsilons, last_eps = epsilons

        # initial gradients ds, db, dk and last_da

        ds = grad_probs * y / last_eps
        db = ds.clone()
        dk = torch.zeros_like(k)
        last_da = ds.sum(dim=-1)

        mask_int = mask.int()

        # walk the checkpointed segments in reverse

        items = zip(
            reversed(checkpointed_a.unbind(dim=0)),
            reversed(checkpointed_b.unbind(dim=0)),
            reversed(num_to_groups(n_iters, checkpoint_segments)),
            reversed(epsilons)
        )

        for ind, (init_a, init_b, segment_iters, eps_init) in enumerate(items):
            is_first = ind == 0

            # run the coordinate descent backward kernel for this segment

            coor_descent_kernel_backward[(n_rows,)](
                dk,
                x,
                init_a,
                init_b,
                mask_int,
                ds,
                db,
                k,
                last_da if is_first else torch.zeros_like(last_da),
                x.stride(0),
                init_b.stride(0),
                mask_int.stride(0),
                ds.stride(0),
                db.stride(0),
                segment_iters,
                eps_init,
                eps_decay,
                eps,
                n_cols,
                num_warps=num_warps,
                BLOCK_SIZE=BLOCK_SIZE
            )

        # accumulate db into ds

        ds += -db
        ds = unpack_one(ds, shape, '* n')

        # if k does not require gradients, set dk to None

        if not k.requires_grad:
            dk = None
        else:
            dk /= k

        return ds, None, dk, None, None, None, None, None
# coordinate descent with the forward and backward passes run as fused triton kernels

@autocast(enabled = False)
def triton_coor_descent(
    s,                        # input tensor
    *,
    n_iters,                  # number of iterations
    k,                        # controls the sparsity of the output
    eps = 1e-1,               # epsilon floor
    eps_init = None,          # initial epsilon
    eps_decay = 1.,           # epsilon decay rate
    mask = None,              # optional mask
    checkpoint_segments = 1   # number of checkpointed segments
):
    # fall back to the plain pytorch implementation on cpu

    if not s.is_cuda:
        return coor_descent(s, n_iters = n_iters, k = k, eps = eps, eps_init = eps_init, eps_decay = eps_decay, mask = mask)

    # use the custom autograd function on cuda

    return _coor_descent.apply(s, n_iters, k, eps, eps_init, eps_decay, mask, checkpoint_segments)
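# usage sketch (illustrative): a drop-in replacement for coor_descent that runs
# fused triton kernels on cuda tensors
#
#   scores = triton_coor_descent(torch.randn(2, 1024).cuda(), n_iters = 20, k = 8, checkpoint_segments = 4)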