PyTorch Acceleration: SDPA (Scaled Dot-Product Attention)

Reposted from: https://mp.weixin.qq.com/s/pkAFDnxYYFlOfY1q2MLsfQ

Transformer has become the most widely used model architecture across domains (text, images, speech), and PyTorch 2.0 further optimizes its Transformer modules to support efficient training and inference of Transformer-based models. Specifically, PyTorch 2.0 introduces a new function in torch.nn.functional: torch.nn.functional.scaled_dot_product_attention, abbreviated here as SDPA. SDPA is backed by high-performance kernels, so you can call it directly to accelerate both training and inference.

Let's first take a quick look at SDPA's signature and parameter descriptions:

torch.nn.functional.scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
) -> Tensor
"""
Args:
    query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`.
    key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`.
    value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`.
    attn_mask (optional Tensor): Attention mask; shape :math:`(N, ..., L, S)`. Two types of masks are supported:
        a boolean mask, where a value of True indicates that the element *should* take part in attention, or
        a float mask of the same dtype as query, key, value that is added to the attention score.
    dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied.
    is_causal (bool): If True, assumes causal attention masking; an error is raised if both
        attn_mask and is_causal are set.
    scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set
        to :math:`\frac{1}{\sqrt{E}}`.

Returns:
    output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.
"""
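
As a quick illustration of the interface, the following minimal call runs on CPU or GPU (the shapes here are arbitrary, chosen only for this example):

import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim)
q = torch.randn(2, 4, 16, 32)
k = torch.randn(2, 4, 16, 32)
v = torch.randn(2, 4, 16, 32)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 4, 16, 32])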

SDPA implements the core of the attention module (scaled dot-product attention); the function is equivalent to the following code:

# L and S are the query and key/value sequence lengths
L, S = Q.size(-2), K.size(-2)
scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale
attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) if is_causal else attn_mask
# A boolean mask is turned into an additive float mask: 0 where attention is allowed, -inf where it is masked out
if attn_mask is not None and attn_mask.dtype == torch.bool:
    attn_mask = torch.zeros_like(attn_mask, dtype=Q.dtype).masked_fill(attn_mask.logical_not(), float("-inf"))
attn_weight = Q @ K.transpose(-2, -1) * scale_factor
attn_weight = torch.softmax(attn_weight if attn_mask is None else attn_weight + attn_mask, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, True)  # train=True
return attn_weight @ V
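
To see that the two agree, here is a rough equivalence check against SDPA using a stripped-down version of the reference computation (no mask, no dropout; naive_sdpa is a helper name introduced just for this sketch):

import math
import torch
import torch.nn.functional as F

def naive_sdpa(Q, K, V):
    # Plain softmax(Q K^T / sqrt(d)) V, matching the reference code above without mask/dropout
    scale_factor = 1 / math.sqrt(Q.size(-1))
    attn_weight = torch.softmax(Q @ K.transpose(-2, -1) * scale_factor, dim=-1)
    return attn_weight @ V

q, k, v = (torch.randn(2, 4, 16, 32) for _ in range(3))
print(torch.allclose(naive_sdpa(q, k, v), F.scaled_dot_product_attention(q, k, v), atol=1e-5))  # True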

The function has also been integrated into PyTorch's existing Transformer APIs, which means you get SDPA's speedups simply by building your model with torch.nn.MultiheadAttention and torch.nn.TransformerEncoderLayer. If you need customized behavior, you can also call the function directly to build your own attention module, as sketched below.
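
For instance, a minimal multi-head self-attention block wired directly on top of SDPA could look like the following (the class name and layer layout are my own, for illustration only):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SDPASelfAttention(nn.Module):
    """Minimal multi-head self-attention built on SDPA (illustrative sketch)."""

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = dropout

    def forward(self, x: torch.Tensor, is_causal: bool = False) -> torch.Tensor:
        B, T, C = x.shape
        # Project to q/k/v and reshape to (batch, heads, seq_len, head_dim)
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        # SDPA dispatches to the best available kernel (flash / mem-efficient / math)
        out = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal
        )
        out = out.transpose(1, 2).reshape(B, T, C)
        return self.proj(out)

x = torch.randn(2, 128, 512)
attn = SDPASelfAttention(embed_dim=512, num_heads=8)
print(attn(x, is_causal=True).shape)  # torch.Size([2, 128, 512])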

SDPA's speedups come from the optimized kernels implemented behind it; three kernels are currently supported:

  • sdpa_flash: FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
  • sdpa_mem_eff: Memory-Efficient Attention
  • sdpa_math: A PyTorch implementation defined in C++

sdpa_flash supports FP16 training and inference on GPUs with SM80+ architectures, while sdpa_mem_eff supports FP16 and FP32 training and inference on most GPUs. If neither of these two kernels applies, SDPA falls back to sdpa_math, a generic implementation written in C++. All three kernels are enabled by default, and when you call SDPA it selects the most suitable kernel for your inputs.

In most cases you do not need to care which kernel is chosen, because the dispatch already picks the best one. If you do want explicit control, you can disable specific kernels with torch.backends.cuda.sdp_kernel(), which is a context manager. For example, to disable sdpa_math:

import torch
import torch.nn.functional as F

query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
# With the math kernel disabled, only the flash and memory-efficient kernels can be dispatched
with torch.backends.cuda.sdp_kernel(enable_math=False):
    F.scaled_dot_product_attention(query, key, value)

With sdpa_math disabled, the dispatcher can only choose between sdpa_flash and sdpa_mem_eff. Disabling two of the kernels is therefore equivalent to forcing the remaining one.
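
For example, reusing the query/key/value tensors from the snippet above, disabling both the math and memory-efficient kernels forces FlashAttention (a sketch; it assumes the hardware and dtype support the flash kernel, otherwise the call raises a RuntimeError rather than silently falling back):

# Only the flash kernel remains enabled
with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(query, key, value)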

We can use sdp_kernel to compare the compute time of the different kernels; the code is as follows:

import torch
import torch.utils.benchmark as benchmark
from torch.backends.cuda import sdp_kernel, SDPBackend
import torch.nn.functional as F

# Let's define a helpful benchmarking function:
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16
device = "cuda"

query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

# Let's explore the speed of each of the 3 implementations

# Helpful arg mapper
backend_map = {
    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {
        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
}

with sdp_kernel(**backend_map[SDPBackend.MATH]):
    print(f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
    try:
        print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")

On a V100, the output looks like this:

The default implementation runs in 6569.854 microseconds
The math implementation runs in 16091.686 microseconds
<timeit-src>:6: UserWarning: Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:527.)
<timeit-src>:6: UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:338.)
<timeit-src>:6: UserWarning: Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:529.)
<timeit-src>:6: UserWarning: Flash attention only supports sm75 and sm8x gpu architectures. Attempting to run on a sm 7.0 gpu. (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:352.)
FlashAttention is not supported. See warnings for reasons.
The memory efficient implementation runs in 6595.339 microseconds

The V100 is an sm 7.0 GPU and does not support FlashAttention, but we can see that the kernel chosen by default is sdpa_mem_eff, which is dramatically faster than sdpa_math (about 6.6 ms vs 16.1 ms). Switching to an A100, the results are as follows:

The default implementation runs in 2831.521 microseconds
The math implementation runs in 7001.696 microseconds
The flash attention implementation runs in 2829.635 microseconds
The memory efficient implementation runs in 3011.410 microseconds

The A100 does support FlashAttention, and the default dispatch selects sdpa_flash, which gives the shortest runtime here; the A100 is also more than twice as fast as the V100.

Finally, let's look at a concrete example: accelerating Stable Diffusion in diffusers with SDPA. diffusers already implements AttnProcessor2_0, which is built on scaled_dot_product_attention:

class AttnProcessor2_0:
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        inner_dim = hidden_states.shape[-1]

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.cross_attention_norm:
            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states

Taking Stable Diffusion 1.5 as an example, we first set the attention processor to the default CrossAttnProcessor:

import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProcessor2_0, CrossAttnProcessor

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
pipe.unet.set_attn_processor(CrossAttnProcessor())

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]

On a V100 this takes about 3.6 s (1.9 s on an A100), with a peak memory footprint of roughly 5.9 GB. Next, we swap the attention processor for AttnProcessor2_0:

pipe.unet.set_attn_processor(AttnProcessor2_0())

After the switch, the runtime drops to about 3 s (1.6 s on an A100) and peak memory to 4.7 GB, so we get a speedup and lower memory consumption at the same time.
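
If you want to reproduce rough numbers like these yourself (this is not the measurement setup used above, just a simple way to get comparable figures), you can time the pipeline and query peak memory via torch.cuda:

import time
import torch

# Reset the peak-memory counter, then time a single generation
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
start = time.time()
image = pipe(prompt).images[0]
torch.cuda.synchronize()

print(f"latency: {time.time() - start:.2f} s")
print(f"peak memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")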

In addition, PyTorch 2.0 introduces torch.compile() for whole-model acceleration, and we can stack it on top of SDPA for a further speedup:

import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProcessor2_0, CrossAttnProcessor

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to(
    "cuda"
)
pipe.unet.set_attn_processor(AttnProcessor2_0())  # this is actually the default
pipe.unet = torch.compile(pipe.unet)

batch_size = 8
steps = 50  # number of denoising steps
prompt = "A photo of an astronaut riding a horse on mars."
images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images

With batch_size=8, this runs in about 16 s on a V100 (6.6 s on an A100), compared with about 17 s (7.3 s on an A100) for the SDPA-only version, so torch.compile() still provides some extra speedup (although the V100 is clearly much weaker than the A100).

Note that the comparisons here are not rigorous; PyTorch has published a more systematic evaluation in the blog post Accelerated Diffusers with PyTorch 2.0.
