pytorch 使用 xformers 库 加速多头注意力计算 和 大幅节省显存

效果概览:
好处:使用 google PALM 架构的小模型做 生成任务,改为 xformers 实现后,加速比为 2倍,显存消耗为原来的 1/3 ,非常给力。
缺点:相比pytorch的原生实现,误差略大。。。

测试显卡:GTX1070,RTX3090

xformers 官方github仓库:https://github.com/facebookresearch/xformers
xformers 官方文档:https://facebookresearch.github.io/xformers/
https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops

前两周 xformers 官方提供了 pypi 和 whl 轮包
windows 和 linux 均可用,最低版本要求 pytorch 1.13.1 版本

pip 安装 xformers

pip install -U xformers

如果需要用于编码器或需要位置偏置,则需要安装 0.17 以上版本
当前(2023/2/26) v0.17 为预发行版,需要使用 --pre 来安装

pip install --pre -U xformers

使用方法

import torch
from xformers.ops import memory_efficient_attention, LowerTriangularMask

device='cuda'
batch = 4
n_head = 8
head_dim = 16
seq_len = 128

q = torch.rand(batch, seq_len, n_head, head_dim).to(device)
k = torch.rand(batch, seq_len, n_head, head_dim).to(device)
v = torch.rand(batch, seq_len, n_head, head_dim).to(device)

# 使用 causal 偏置掩码
o = memory_efficient_attention(q, k, v, LowerTriangularMask())

# 不使用任何偏置掩码
o = memory_efficient_attention(q, k, v)

# 使用自定义的偏置掩码 attn_bias,要求 xformers 版本 大于等于 0.17
## 这里的 from_len,to_len 分别代表Decoder的序列长度,Encoder的序列长度
from_len = seq_len
to_len = seq_len
attn_bias = torch.rand(batch, n_head, from_len, to_len).to(device)
o = memory_efficient_attention(q, k, v, attn_bias)

memory_efficient_attention 的等效pytorch代码实现,用于模型导出
改自 https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops
注:LowerTriangularMask 之类的掩码请另外手动生成一个等效的 attn_bias Tensor 再用于本函数

import torch.nn.functional as F

def memory_efficient_attention_pytorch(query, key, value, attn_bias=None, p=0., scale=None):
    # query     [batch, seq_len, n_head, head_dim]
    # key       [batch, seq_len, n_head, head_dim]
    # value     [batch, seq_len, n_head, head_dim]
    # attn_bias [batch, n_head, seq_len, seq_len]

    if scale is None:
        scale = 1 / query.shape[-1] ** 0.5
    
    # BLHC -> BHLC
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    query = query * scale
    # BHLC @ BHCL -> BHLL
    attn = query @ key.transpose(-2, -1)
    if attn_bias is not None:
        attn = attn + attn_bias
    attn = attn.softmax(-1)
    attn = F.dropout(attn, p)
    # BHLL @ BHLC -> BHLC
    out = attn @ value
    # BHLC -> BLHC
    out = out.transpose(1, 2)
    return out

  • 7
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 10
    评论
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值