triton之flag-attention源码分析

一 结构分析

1 FlashAttention类

主要的实现,在一个类FlashAttention中实现

 这个类的主要结构:

class FlashAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_attention):
    @staticmethod
    def backward(ctx, do, *ignored):

在这个类中调用的其他函数都不是类的内部的函数,类的内部的函数只有这两个 

这两个函数是一定要这样实现的,是固定写法

2 调用接口

def attention(q, k, v, causal=False, sm_scale=None,
              return_log_normalizer=False, return_total_attention=False,
):

    return FlashAttention.apply(q, k, v, causal, sm_scale, return_log_normalizer, return_total_attention)

外部使用这个类,需要通过调用attention函数来实现

3 包管理(__init__.py)

try:
    from ._version import version as __version__
    from ._version import version_tuple
except ImportError:
    __version__ = "0.0.0"
    version_tuple = (0, 0, 0)


from flag_attn.piecewise import attention as piecewise_attention # noqa: F401
from flag_attn.flash import attention as flash_attention # noqa: F401
from flag_attn.split_kv import attention as flash_attention_split_kv # noqa: F401
from flag_attn.paged import attention as paged_attention # noqa: F401

from flag_attn import testing # noqa: F401

 此文件的位置为:

通过源码可以发现,将paged.py中的attention函数导出为paged_attention 

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

youzjuer

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值