ALiBi线性偏置注意力

一、目录

1.公式
2.实现

二、实现

1.公式在这里插入图片描述
m 的取值公式:2^(-8/n) n为head 头数
参考:https://zhuanlan.zhihu.com/p/632780188
2. 实现
github: https://github.com/ofirpress/attention_with_linear_biases/blob/master/fairseq/models/transformer.py#L941

import math
import torch
class TransformerDecoder():
    def __init__( self, args):
        self.args=args
        self._future_mask = torch.empty(0)
        #求坡度m
        def get_slopes(n):
            def get_slopes_power_of_2(n):
                start = (2 ** (-2 ** -(math.log2(n) - 3)))
                ratio = start
                return [start * ratio ** i for i in range(n)]

            if math.log2(n).is_integer():
                return get_slopes_power_of_2(
                    n)  # In the paper, we only train models that have 2^a heads for some a. This function has
            else:  # some good properties that only occur when the input is a power of 2. To maintain that even
                closest_power_of_2 = 2 ** math.floor(
                    math.log2(n))  # when the number of heads is not a power of 2, we use this workaround.
                return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][
                                                                   :n - closest_power_of_2]

        maxpos = args.tokens_per_sample       # max number of tokens per sequence  每个序列的最大令牌数,最大有效长度
        attn_heads = args.decoder_attention_heads
        self.slopes = torch.Tensor(get_slopes(attn_heads))

        self.alibi = self.slopes.unsqueeze(1).unsqueeze(1) * torch.arange(maxpos).unsqueeze(0).unsqueeze(0).expand(
            attn_heads, -1, -1)
        self.alibi = self.alibi.view(attn_heads, 1, maxpos)     #[head, 1, maxpos]
        self.alibi = self.alibi.repeat(args.max_tokens // maxpos, 1, 1)  # batch_size, 1, 1

    #目的:将alibi 矩阵进行数据类型转变
    def buffered_future_mask(self, tensor):
        dim = tensor.size(1)
        # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
        if (
                self._future_mask.size(0) == 0
                or (not self._future_mask.device == tensor.device)
                or self._future_mask.size(1) < self.args.tokens_per_sample
        ):
            #求下三角矩阵
            self._future_mask = torch.triu(
                fill_with_neg_inf(torch.zeros([self.args.tokens_per_sample, self.args.tokens_per_sample])), 1
            )
            self._future_mask = self._future_mask.unsqueeze(0) + self.alibi
        self._future_mask = self._future_mask.to(tensor)         #转变为tensor 一样的数据类型
        return self._future_mask[:tensor.shape[0] * self.args.decoder_attention_heads, :dim, :dim]


def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float("-inf")).type_as(t)

if __name__ == '__main__':
    class A():
        tokens_per_sample=8                 #每个序列的最大令牌数
        decoder_attention_heads=8
        max_tokens=256                      #每个批次中的最大令牌数
    args=A()
    m=TransformerDecoder(args)

    input_embedding=torch.randn(size=(2,256,64))
    attention_mask=m.buffered_future_mask(input_embedding)   #m*矩阵
    print(attention_mask)

    # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    #
    # if attention_mask is not None:
    #     if q_len == 1:  # inference with cache
    #         if len(attention_mask.size()) == 4:
    #             attention_mask = attention_mask[:, :, -1:, :]
    #         else:
    #             attention_mask = attention_mask[:, -1:, :]
    #     attn_weights = attn_weights + attention_mask
    #     attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

  • 9
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值