Several variants of self-attention

  • Multi-head attention
  • Multi-query attention
  • Grouped-query attention

Self-attention is used heavily when training and serving large models, and the query, key, and value matrices of every attention layer have to be kept in GPU memory. In multi-head attention each head has its own query, key, and value matrices, so the KV cache takes a large amount of memory. Multi-query attention instead shares a single key and value matrix across all heads, which sharply reduces memory use during training and inference, although it can hurt model quality. Grouped-query attention is the middle ground: the heads are split into groups, and the heads within a group share one key and value matrix. When the number of groups equals the number of heads it reduces to multi-head attention, and when there is only one group it becomes multi-query attention. The sketch below illustrates the resulting difference in KV-cache size.
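A minimal sketch with illustrative numbers (not taken from the article): the per-token, per-layer KV-cache footprint scales with the number of key/value heads g, so MQA (g = 1) and GQA (1 < g < h) store far less than MHA (g = h).

d_model, h = 512, 8              # hidden size and number of query heads (illustrative)
head_dim = d_model // h          # 64

def kv_entries_per_token(g: int) -> int:
    # factor 2: one key vector and one value vector per KV head
    return 2 * g * head_dim

print(kv_entries_per_token(8))   # MHA, g == h:  1024 cached values per token
print(kv_entries_per_token(2))   # GQA, g == 2:   256 cached values per token
print(kv_entries_per_token(1))   # MQA, g == 1:   128 cached values per token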
Below is code for the three self-attention variants:

  1. Multi-head attention
import torch

# Incremental (single-step decoding) multi-head attention
def MultiheadSelfAttentionIncremental():
    """
    d_model: model hidden size
    b: batch size
    h: number of heads
    d_k: key dimension per head
    d_v: value dimension per head
    """
    # Hidden size 512, batch size 32, 8 heads, per-head key/value dim 512 // 8 = 64
    d_model, b, h, d_k, d_v = 512, 32, 8, (512 // 8), (512 // 8)

    m = 5  # number of tokens already cached
    # Key and value caches computed in earlier steps: here we pretend 5 tokens
    # have already been processed (randomly initialised for illustration)
    prev_K = torch.rand(b, h, m, d_k)
    prev_V = torch.rand(b, h, m, d_v)

    X = torch.rand(b, d_model)  # current-step input used for the query
    M = torch.rand(b, d_model)  # current-step input used for key and value
    # Projection matrices for q, k, v and the output
    P_q = torch.rand(h, d_model, d_k)  # W_q
    P_k = torch.rand(h, d_model, d_k)  # W_k
    P_v = torch.rand(h, d_model, d_v)  # W_v
    P_o = torch.rand(h, d_model, d_v)  # W_o

    q = torch.einsum("bd,hdk->bhk", X, P_q)  # project the input to per-head queries
    new_K = torch.concat(
        [prev_K, torch.einsum("bd,hdk->bhk", M, P_k).unsqueeze(2)], dim=2
    )  # prev_K is (batch, heads, cached tokens, d_k); the new token's key is computed
    #    with einsum and concatenated along the cached-token dimension (dim=2)
    new_V = torch.concat(
        [prev_V, torch.einsum("bd,hdv->bhv", M, P_v).unsqueeze(2)], dim=2
    )
    # Attention weights via softmax
    logits = torch.einsum("bhk,bhmk->bhm", q, new_K)  # q · k scores
    weights = torch.softmax(logits, dim=-1)
    O = torch.einsum("bhm,bhmv->bhv", weights, new_V)
    y = torch.einsum("bhv,hdv->bd", O, P_o)
    return y, new_K, new_V

if __name__ == "__main__":
    print(MultiheadSelfAttentionIncremental())
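A quick shape check (illustrative; the sizes simply follow the constants hard-coded above): the returned caches now hold m + 1 = 6 tokens for each of the 8 heads.

y, new_K, new_V = MultiheadSelfAttentionIncremental()
print(y.shape)      # torch.Size([32, 512])      -> (batch, d_model)
print(new_K.shape)  # torch.Size([32, 8, 6, 64]) -> (batch, heads, cached tokens, d_k)
print(new_V.shape)  # torch.Size([32, 8, 6, 64])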
  2. Multi-query attention
import torch

# Incremental (single-step decoding) multi-query attention
def MultiquerySelfAttentionIncremental():
    # Model hidden size, batch size, number of heads, per-head key dim, per-head value dim
    d, b, h, k, v = 512, 32, 8, (512 // 8), (512 // 8)

    m = 5  # assume the sequence already has 5 tokens
    # Key and value caches for the 5 existing tokens. Multi-query attention keeps a
    # single key and value matrix no matter how many heads there are, so unlike the
    # multi-head code above the caches have no head dimension.
    prev_K = torch.rand(b, m, k)
    prev_V = torch.rand(b, m, v)

    X = torch.rand(b, d)  # current-step input used for the query
    M = torch.rand(b, d)  # current-step input used for key and value
    # Projection matrices for q, k, v and the output
    P_q = torch.rand(h, d, k)  # W_q
    P_k = torch.rand(d, k)  # W_k (shared by all heads)
    P_v = torch.rand(d, v)  # W_v (shared by all heads)
    P_o = torch.rand(h, d, v)  # W_o

    q = torch.einsum("bd,hdk->bhk", X, P_q)
    K = torch.concat([prev_K, torch.einsum("bd,dk->bk", M, P_k).unsqueeze(1)], dim=1)
    V = torch.concat([prev_V, torch.einsum("bd,dv->bv", M, P_v).unsqueeze(1)], dim=1)
    logits = torch.einsum("bhk,bmk->bhm", q, K)
    weights = torch.softmax(logits, dim=-1)
    O = torch.einsum("bhm,bmv->bhv", weights, V)
    y = torch.einsum("bhv,hdv->bd", O, P_o)
    return y, K, V

if __name__ == "__main__":
    print(MultiquerySelfAttentionIncremental())
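The same shape check for the multi-query version (again just echoing the hard-coded constants) makes the memory saving visible: the cached K and V have no head dimension, so they are 1/h the size of the multi-head caches above.

y, K, V = MultiquerySelfAttentionIncremental()
print(y.shape)  # torch.Size([32, 512])    -> (batch, d)
print(K.shape)  # torch.Size([32, 6, 64])  -> (batch, cached tokens, k), shared by all heads
print(V.shape)  # torch.Size([32, 6, 64])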
  3. Grouped-query attention
"""
在grouped-query attention中
当组数与头数相同时则为multi-head attention
当组数为1时则为multi-query attention
"""
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Equivalent to torch.repeat_interleave(x, dim=2, repeats=n_rep)."""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:  # n_kv_heads == n_heads: plain multi-head attention, nothing to expand
        return x
    return (  # MQA / GQA: repeat each KV head so that every query head gets a copy
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size  # KV heads on this rank
        self.n_rep = self.n_local_heads // self.n_local_kv_heads  # query heads per KV group
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,  # one K projection per KV head (shared within a group)
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,  # one V projection per KV head (shared within a group)
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads: expand each KV group to cover all query heads
        keys = repeat_kv(keys, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
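
A minimal usage sketch of repeat_kv with hypothetical sizes: with 8 query heads and 2 KV heads, n_rep = 8 // 2 = 4, so after expansion every cached KV head is shared by 4 query heads.

bs, slen, n_kv_heads, head_dim = 1, 5, 2, 64     # hypothetical sizes
kv = torch.rand(bs, slen, n_kv_heads, head_dim)
expanded = repeat_kv(kv, n_rep=4)
print(expanded.shape)                            # torch.Size([1, 5, 8, 64])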