一 结构分析
1 FlashAttention类
主要的实现,在一个类FlashAttention中实现
这个类的主要结构:
class FlashAttention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_attention):
@staticmethod
def backward(ctx, do, *ignored):
在这个类中调用的其他函数都不是类的内部的函数,类的内部的函数只有这两个
这两个函数是一定要这样实现的,是固定写法
2 调用接口
def attention(q, k, v, causal=False, sm_scale=None,
return_log_normalizer=False, return_total_attention=False,
):
return FlashAttention.apply(q, k, v, causal, sm_scale, return_log_normalizer, return_total_attention)
外部使用这个类,需要通过调用attention函数来实现
3 包管理(__init__.py)
try:
from ._version import version as __version__
from ._version import version_tuple
except ImportError:
__version__ = "0.0.0"
version_tuple = (0, 0, 0)
from flag_attn.piecewise import attention as piecewise_attention # noqa: F401
from flag_attn.flash import attention as flash_attention # noqa: F401
from flag_attn.split_kv import attention as flash_attention_split_kv # noqa: F401
from flag_attn.paged import attention as paged_attention # noqa: F401
from flag_attn import testing # noqa: F401
此文件的位置为:
通过源码可以发现,将paged.py中的attention函数导出为paged_attention