新版本FasterTransformer的FUSED_MHA

关于 UNFUSED_PADDED_MHA VS FUSED_MHA

  • FUSED_MHA用了另一种kernel的执行方法(和添加链接描述相同,将在下一个section说明)
  • UNFUSED_PADDED 的 KERNELS执行代码在 src/fastertransformer/kernels/unfused_attention_kernels.cu
    在这里插入图片描述

在这里插入图片描述

enum class AttentionType {
    UNFUSED_MHA,
    UNFUSED_PADDED_MHA,
    FUSED_MHA,
    FUSED_PADDED_MHA
};

/* NOTE:
1. only swin-style relative position bias is supported currently
2. gpt-style (causal-mask) models support any-sequence-length fmha, so we don't need to call isValidSeqLen at run-time
3. bert/vit can also support any-seq-length fmha
*/
template<typename T>
AttentionType getAttentionType(size_t     size_per_head,
                               const int  sm,
                               const bool remove_padding,
                               const int  max_seq_len,
                               const bool is_fuse                          = true,
                               const bool with_swin_relative_position_bias = false,
                               const bool causal_mask                      = false)
{

    if (std::is_same<T, half>::value && is_fuse) {
        // Bert/Vit
        if (!causal_mask) {
            if (!with_swin_relative_position_bias
                && (((sm == kSM_70 || sm == kSM_72) && size_per_head == 64)
                    || ((sm == kSM_75 || sm == kSM_80 || sm == kSM_86)
                        && (size_per_head == 64 || size_per_head == 32)))) {
                return remove_padding ? AttentionType::FUSED_MHA : AttentionType::FUSED_PADDED_MHA;
            }
            else if (with_swin_relative_position_bias && (sm == kSM_75 || sm == kSM_80 || sm == kSM_86)
                     && max_seq_len <= 256 && size_per_head == 32) {
                return remove_padding ? AttentionType::FUSED_MHA : AttentionType::FUSED_PADDED_MHA;
            }
        }
        // GPT and its variants
        else {
           // FMHA_ENABLE only affects gpt-style models (causal-mask)
            char * fused_qkv = std::getenv("FMHA_ENABLE");
            if (fused_qkv != nullptr && std::string(fused_qkv) == "ON") {
                if ((sm == kSM_70 || sm == kSM_72 || sm == kSM_75 || sm == kSM_80 || sm == kSM_86 || sm == kSM_89)
                    && (size_per_head == 32 || size_per_head == 40 || size_per_head == 64 || size_per_head == 80
                        || size_per_head == 128 || size_per_head == 144 || size_per_head == 160 || size_per_head == 256)) {
                    return remove_padding ? AttentionType::FUSED_MHA : AttentionType::UNFUSED_PADDED_MHA;
                }
            }
        }
    }
  • 如果想执行FUSED_MHA,需要将参数设置如下:
    在这里插入图片描述

FUSED_MHA

在这里插入图片描述

所以有关核函数的定义调用等还在forward部分:

在这里插入图片描述
在这里插入图片描述

https://github1s.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/bert/Bert.cc#L494

在这里插入图片描述

调用了FusedAttentionLayer的传播函数

在这里插入图片描述

传播函数的融合部分

在这里插入图片描述

Dispatcher_fp16为指向MHARunner类型的指针
在这里插入图片描述

实际上通过 .reset()实现了多态:
在这里插入图片描述

最终调用pimpl->run
在这里插入图片描述

指针pimpl对应的内部类的定义在
https://github1s.com/NVIDIA/FasterTransformer/blob/main/3rdparty/trt_fused_multihead_attention/qkvToContext.cu#L62

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值