一 结构分析
1 FlashAttention类
主要的实现,在一个类FlashAttention中实现

这个类的主要结构:
class FlashAttention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_attention):
@staticmethod
def backward(ctx, do, *ignored):
在这个类中调用的其他函数都不是类的内部的函数,类的内部的函数只有这两个
这两个函数是一定要这样实现的,是固定写法
2 调用接口
def attention(q, k, v, causal=False, sm_scale=None,
return_log_normalizer=False, return_total_attention=False,
):
return FlashAttention.apply(q, k, v, causal, sm_scale, return_log_normalizer, return_total_attention)
外部使用这个类,需要通过调用attention函数来实现
3 包管理(__init__.py)
try:
from

最低0.47元/天 解锁文章
571

被折叠的 条评论
为什么被折叠?



