SAM-FAST:Accelerating Generative AI with PyTorch: Segment Anything, Fast基于官方PyTorch团队开发原生SAM提速8倍

今年上半年的时候最耀眼的光是属于大模型的,四月份的时候SAM横空出世带来了诸多惊喜,当时我写了对应的博文,感兴趣的话可以自行移步阅读即可,如下:

《Segment Anything Model (SAM)——卷起来了,那个号称分割一切的CV大模型他来了》

《Segment Anything Model (SAM)——分割一切,具有预测提示输入的图像分割实践》

大模型顾名思义核心在于大,极大的参数量带来的不可避免的问题就是难以训练、难以部署、难以在有限的成本条件下快速推理响应,一方面考虑加大硬件算力的投入,另一方面就是考虑再模型本身上做文章,这里本文的核心思想就是介绍官方PyTorch原生开发团队基于原生的开发模式对SAM重写构建更为高效的计算模型。

官方项目地址在这里,如下所示:

目前已经有超过4.2w的star量了,着实是很恐怖了。

就在上周PyTorch团队重写了SAM模型,目的就是为了尽可能释放出来纯原生PyTorch的性能来完成加速计算,据该团队最新的成果报告显示重写后的SAM模型性能提升了8倍,接下来我们来详细看下,官方博文地址在这里,如下所示:

这篇文章是一个多系列博客的第一部分,重点是如何使用纯原生PyTorch加速生成人工智能模型。我们很高兴能分享一系列新发布的PyTorch性能功能,以及如何将这些功能组合在一起的实际示例,看看我们能在多大程度上推动PyTorch原生性能。

正如在2023年PyTorch开发者大会上宣布的那样,PyTorch团队重写了Meta的Segment Anything(“SAM”)模型,使代码比原始实现快8倍,并且没有损失准确性,所有这些都使用了本地PyTorch优化。我们利用了PyTorch的一系列新功能:

      Torch.compile:PyTorch模型的编译器

      GPU量化:以降低的精度操作加速模型

     标度点积注意力(SDPA):记忆高效注意力实现

     半结构化(2:4)稀疏:GPU优化的稀疏内存格式

     嵌套张量:将大小不一致的数据批处理成一个张量,例如不同大小的图像。

    使用Triton自定义算子:使用Triton Python DSL编写GPU操作,并通过自定义算子注册将其轻松集成到PyTorch的各种组件中。

我们鼓励读者在Github上复制我们SAM实现中的粘贴代码,并在Github上向我们提问。官方同样开源了代码,项目地址在这里,如下所示:

通过我们最新发布的PyTorch原生功能,快速了解吞吐量的增加和内存开销的减少。基准测试在p4d.24xlarge实例(8x A100s)上运行。

SEGMENTANYTHING MODEL

SAM是一种用于生成可提示图像掩模的零样本视觉模型。

SAM架构[在其论文中描述]包括基于Transformer架构的多个提示和图像编码器。其中,我们测量了最小和最大的视觉转换器主干ViT-B和ViT-H的性能。为了简单起见,我们只显示ViT-B模型的轨迹。

OPTIMIZATIONS

下面我们讲述优化SAM的故事:分析、识别瓶颈,并在PyTorch中构建解决这些问题的新功能。在整个过程中,我们展示了PyTorch的新功能:torch.compile、SDPA、Triton内核、嵌套张量和半结构化稀疏性。以下部分是逐步构建的,以我们的SAM fast结束,现在可以在Github上使用。我们使用真实的内核和内存跟踪,使用完全PyTorch的本地工具来激励每个功能,并使用Perfetto UI可视化这些跟踪。

Baseline

我们的SAM基线是Facebook Research的未修改模型,使用float32数据类型和1的批量大小。在一些初始的预热之后,我们可以使用PyTorch Profiler查看内核跟踪:

我们注意到两个领域的优化时机已经成熟。

第一种是对aten::index的长调用,这是Tensor索引操作(例如[])产生的底层调用。而实际花费在aten::index上的GPU时间相对较低。aten::index正在启动两个内核,并且在这两个内核之间正在发生阻塞的cudaStreamSynchronize。这意味着CPU正在等待GPU完成处理,直到它启动第二个内核。要优化SAM,我们应该致力于消除导致空闲时间的阻塞GPU同步。

第二个是在矩阵乘法中花费在GPU上的大量时间(以上流7 7中的深绿色)。这在变压器中很常见。如果我们能减少GPU在矩阵乘法上花费的时间,我们就能显著加快SAM。

我们可以测量开箱即用SAM的吞吐量(img/s)和内存开销(GiB),以建立基线:

Bfloat16半精度(+GPU同步和批处理)Bfloat16 Half precision (+GPU syncs and batching)

为了解决在矩阵乘法中花费较少时间的第一个问题,我们可以转向bfloat16。Bfloat16是一种常用的半精度类型。通过降低每个参数的精度和激活,我们可以节省大量的计算时间和内存。随着参数精度的降低,验证端到端模型的准确性至关重要。

这里显示的是用半精度bfloat16替换填充数据类型的示例。代码在这里

除了简单地将model.to设置为(torch.bfloat16)之外,我们还需要更改一些假设为默认dtype的小地方。

现在,为了删除GPU同步,我们需要审计导致它们的操作。我们可以通过在GPU跟踪中搜索对cudaStreamSynchronize的调用来找到这些代码。事实上,我们找到了两个可以重写为无同步的位置。

具体来说,我们看到在SAM的图像编码器中,有一些变量充当坐标缩放器、q_coords和k_coords。这些都是在CPU上分配和处理的。但是,一旦这些变量用于在rel_pos_resize中进行索引,索引操作就会自动将这些变量移动到GPU。这种复制导致了我们在上面观察到的GPU同步。我们注意到SAM的提示编码器中的第二个索引调用:我们可以使用torch.where重写它,如上所示。

Kernel trace 内核跟踪

在应用这些更改之后,我们开始看到各个内核调用之间有相当长的时间。由于启动内核的GPU开销,这通常在小批量(此处为1)的情况下观察到。为了更深入地了解优化的实际领域,我们可以开始对批量大小为8的SAM推断进行评测:

从每个内核花费的时间来看,我们将SAM的大部分GPU时间都花在了elementwise内核和softmax操作上。有了这一点,我们现在看到矩阵乘法已经成为一个小得多的相对开销。

将GPU同步和bfloat16优化结合在一起,我们现在已将SAM性能提高了3倍

Torch.compile (+graph breaks and CUDA graphs)

当观察到大量的小操作时,比如上面介绍的elementwise内核,使用编译器来融合操作会有很大的好处。PyTorch最近发布的torch.compile通过以下方面进行了出色的优化:

      1、将诸如nn之类的操作序列融合在一起。LayerNorm或nn。GELU集成到一个名为和

      2、Epilogues:融合紧跟矩阵乘法内核的操作,以减少GPU内核调用的数量。

通过这些优化,我们减少了GPU全局内存往返次数,从而加快了推理速度。我们现在可以在SAM的图像编码器上尝试torch.compile。为了最大限度地提高性能,我们使用了一些高级编译技术,例如:

      1、使用torch.compile的最大自动调谐模式,可以使用自定义epilogue实现CUDA图形和特定形状的内核

      2、通过设置TORCH_LOGS=“graph_breaks,recompiles”,我们可以手动验证我们没有遇到图中断或重新编译。

      3、用零填充输入到编码器的一批图像可以确保编译接受静态形状,从而能够始终使用具有自定义后缀的形状特定优化内核,而无需重新编译。

predictor.model.image_encoder = \
    torch.compile(predictor.model.image_encoder, mode=use_compile)

Kernel Trace 内核跟踪

torch.compile运行得很好。我们推出了一个CUDA图,它在定时区域内占GPU时间的很大一部分。让我们再次运行我们的配置文件,看看GPU在特定内核中花费的时间百分比:

我们现在看到softmax在各种GEMM变体之后的时间中占了很大一部分。总之,我们观察到以下针对批量大小8及以上变化的测量结果。

SDPA:scaled_dot_product_attention

接下来,我们可以解决transformer 性能开销最常见的领域之一:注意力机制。原生的注意力实现在时间和记忆上随序列长度呈二次方分布。PyTorch的scaled_dot_product_attention操作建立在Flash attention、FlashAttentionV2和xFormer的内存高效注意力原理之上,可以显著加快GPU的注意力。该操作与torch.compile相结合,使我们能够在MultiheadAttention的变体中表达和融合一个通用模式。经过一小部分更改后,我们可以调整模型以使用scaled_dot_product_attention。

Kernel Trace 内核跟踪

我们现在可以看到,特别是内存高效注意力内核在GPU上占用了大量的计算时间:

使用PyTorch的原生scaled_dot_product_attetion,我们可以显著增加批量大小。我们现在观察以下批量大小32及以上变化的测量结果。

Triton:用于融合相对位置编码的自定义SDPA

暂时脱离推理吞吐量,我们开始分析整个SAM内存。在图像编码器中,我们看到了内存分配的显著峰值:

放大后,我们可以看到这种分配发生在add_decomposed_rel_pos中,如下所示:

这里的attn变量是两个较小张量的相加:形状(B,q_h,q_w,k_h,1)的rel_h和形状(B、q_h、q_w、1,k_w)的rel_w。

内存高效注意力内核(通过SDPA使用)在注意力偏差大小超过3.0GiB的情况下需要很长时间,这并不奇怪。如果我们不分配这个大的attn张量,而是将两个较小的rel_h和rel_w张量线程到SDPA中,并且只根据需要构建attn,我们预计会显著提高性能。

不幸的是,这不是一个微不足道的修改;SDPA内核经过高度优化,并使用CUDA编写。我们可以求助于Triton,以及他们关于FlashAttention实现的易于理解和使用的教程。经过大量挖掘,并与xFormer的Daniel Haziza密切合作,我们发现了一个输入形状的例子,在这个例子中,实现内核的融合版本相对简单。详细信息已添加到存储库中。令人惊讶的是,对于推理情况,这可以在不到350行的代码中完成。

这是一个用新内核直接构建的PyTorch扩展的好例子。

Kernel trace 内核跟踪

使用我们的定制定位Triton内核,我们观察到以下批量大小为32的测量结果。

NT:嵌套张量和批处理predict_tarch

我们在图像编码器上花了很多时间。这是有道理的,因为它占用了最多的计算时间。然而,在这一点上,它已经得到了很好的优化,并且花费最多时间的运营商将需要大量的额外投资来改进。

我们在掩码预测管道中发现了一个有趣的观察结果:对于我们拥有的每个图像,都有一个相关的大小、坐标和fg_labels张量。这些张量中的每一个都具有不同的批量大小。每个图像本身也具有不同的大小。这种数据表示看起来像Jagged data。使用PyTorch最近发布的NestedTensor,我们可以将数据管道批处理coords和fg_labels Tensor修改为单个Nested张量。这对于图像编码器之后的提示编码器和掩码解码器具有显著的性能优势。调用:

torch.nested.nested_tensor(data, dtype=dtype, layout=torch.jagged)

Kernel trace 内核追踪

我们现在可以看到,我们可以从CPU启动内核的速度比GPU处理的速度快得多,并且在我们的定时区域结束时等待GPU完成需要花费很长时间(cudaDeviceSynchronize)。我们也看不到GPU上内核之间有任何空闲时间(空白)。

使用嵌套张量,我们观察到批量大小为32及以上的变化的以下测量结果。

int8: quantization and approximating matmul

我们在上面的跟踪中注意到,现在在GEMM内核中花费了大量时间。我们已经进行了足够的优化,现在我们看到矩阵乘法在推理中比缩放的点积注意力占据了更多的时间。

在从fp32到bfloat16的早期学习的基础上,让我们更进一步,用int8量化模拟更低的精度。在量化方法方面,我们专注于动态量化,其中我们的模型观察层的可能输入和权重的范围,并细分可表达的int8范围,以均匀地“分散”观察值。最终,每个浮点输入将被映射到范围[-128127]中的单个整数。有关更多信息,请参阅PyTorch的量化教程

降低精度可以立即节省峰值内存,但要实现推理加速,我们必须通过SAM的操作充分利用int8。这需要建立一个高效的int8@int8矩阵乘法内核,以及从高精度转换到低精度(量化)以及从低精度反转回高精度(反量化)的投射逻辑。利用torch.compile的强大功能,我们可以将这些量化和反量化例程编译并融合为矩阵乘法的有效单核和表。由此产生的实现相当短,不到250行代码。有关API和用法的更多信息,请参阅pytorch labs/ao。

虽然在推理时量化模型时通常会看到一些精度回归,但SAM在精度损失最小的情况下对低精度推理特别稳健。添加了量化后,我们现在观察到批量大小32及以上变化的以下测量结果。

稀疏:半结构化(2:4)稀疏性

矩阵乘法仍然是我们的瓶颈。我们可以使用另一种经典的近似矩阵乘法的方法:稀疏化来参考模型加速策略。通过稀疏化我们的矩阵(即,将值清零),理论上我们可以使用更少的比特来存储权重和激活张量。我们决定将张量中的哪些权重设置为零的过程称为修剪。修剪背后的想法是,权重张量中的小权重对层的净输出贡献很小,通常是权重与激活的乘积。修剪掉较小的权重可以在不显著损失准确性的情况下潜在地减小模型大小。

修剪的方法多种多样,从完全非结构化(其中贪婪地修剪权重)到高度结构化(其中一次修剪张量的大个子分量)。方法的选择并非易事。虽然非结构化修剪在理论上可能对精度的影响最小,但GPU在乘以大的密集矩阵方面也非常高效,并且在稀疏状态下可能会遭受显著的性能下降。PyTorch最近支持的一种修剪方法试图达到一种平衡,称为半结构化(或2:4)稀疏性。这种稀疏存储将原始张量显著减少了50%,同时产生了可以利用高性能2:4 GPU内核的密集张量输出。请参阅下图以获取图示。

来自developer.nvidia.com/blog/exploiting-apere-structured-sparsity-with-csparselt

为了使用这种稀疏存储格式和相关的快速内核,我们需要修剪权重,使其符合格式的约束。我们在1乘4的区域中选择两个最小的权重进行修剪,以衡量性能与精度的权衡。很容易将权重从默认的PyTorch(“跨步”)布局更改为这种新的半结构化稀疏布局。要实现apply_s稀疏(模型),我们只需要32行Python代码:

import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor

# Sparsity helper functions
def apply_fake_sparsity(model):
    """
    This function simulates 2:4 sparsity on all linear layers in a model.
    It uses the torch.ao.pruning flow.
    """
    # torch.ao.pruning flow
    from torch.ao.pruning import WeightNormSparsifier
    sparse_config = []
    for name, mod in model.named_modules():
        if isinstance(mod, torch.nn.Linear):
            sparse_config.append({"tensor_fqn": f"{name}.weight"})

    sparsifier = WeightNormSparsifier(sparsity_level=1.0,
                                      sparse_block_shape=(1,4),
                                      zeros_per_block=2)
    sparsifier.prepare(model, sparse_config)
    sparsifier.step()

    sparsifier.step()
    sparsifier.squash_mask()


def apply_sparse(model):
    apply_fake_sparsity(model)
    for name, mod in model.named_modules():
        if isinstance(mod, torch.nn.Linear):
            mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))

在2:4稀疏性的情况下,我们观察到具有vit_b和批大小32的SAM的峰值性能:

结论

最后,我们很高兴宣布了迄今为止最快的Segment Anything实现。我们使用一系列新发布的功能,在不损失准确性的情况下,用纯PyTorch重写了Meta的原始SAM:

      1、Torch.compile PyTorch的原生JIT编译器,提供PyTorch操作的快速、自动化融合[教程]

      2、GPU量化以降低的精度操作加速模型[api]

      3、标度点积注意力(SDPA)注意力的一种新的、内存高效的实现[教程]

      4、半结构化(2:4)稀疏性用更少的比特来加速模型,以存储权重和激活[教程]

      5、嵌套张量高度优化、粗糙的阵列处理,适用于不均匀的批处理和图像大小[教程]

      6、Triton 内核。自定义GPU操作,通过Triton轻松构建和优化

有关如何重现这篇博客文章中呈现的数据的更多细节,请快速查看片段的实验文件夹

在我们的下一篇文章中,我们很高兴能与PyTorch原生创作的LLM分享类似的性能提升!

  • 7
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Together_CZ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值