Grouped Query Attention (GQA): A PyTorch Implementation

Most of the GQA implementations I've come across online look rather convoluted and never feel concise, so here is a clean, readable way to implement GQA:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, num_groups):
        super().__init__()
        assert num_heads % num_groups == 0, "num_heads must be divisible by num_groups"
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.head_dim = embed_dim // num_heads
        self.group_dim = self.num_groups * self.head_dim  # total dim of the shared K/V heads
        self.scale = self.head_dim ** -0.5

        # Projections
        self.q_proj = nn.Linear(embed_dim, embed_dim)  # Query: full embed_dim for num_heads
        self.k_proj = nn.Linear(embed_dim, self.group_dim)  # Key: group_dim for num_groups
        self.v_proj = nn.Linear(embed_dim, self.group_dim)  # Value: group_dim for num_groups
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape

        # Project inputs to q, k, v
        q = self.q_proj(x)  # Shape: (batch_size, seq_len, embed_dim)
        k = self.k_proj(x)  # Shape: (batch_size, seq_len, group_dim)
        v = self.v_proj(x)  # Shape: (batch_size, seq_len, group_dim)

        # Reshape query for multi-head attention
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Shape: (batch_size, num_heads, seq_len, head_dim)

        # Reshape key and value for grouped attention
        k = k.view(batch_size, seq_len, self.num_groups, self.head_dim).transpose(1, 2)
        # Shape: (batch_size, num_groups, seq_len, head_dim)
        v = v.view(batch_size, seq_len, self.num_groups, self.head_dim).transpose(1, 2)
        # Shape: (batch_size, num_groups, seq_len, head_dim)

        # Repeat k and v to match the number of query heads
        heads_per_group = self.num_heads // self.num_groups
        k = k.repeat_interleave(heads_per_group, dim=1)
        # Shape: (batch_size, num_heads, seq_len, head_dim)
        v = v.repeat_interleave(heads_per_group, dim=1)
        # Shape: (batch_size, num_heads, seq_len, head_dim)

        # Compute attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # Shape: (batch_size, num_heads, seq_len, seq_len)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)  # Shape: (batch_size, num_heads, seq_len, head_dim)

        # Reshape and project output
        out = out.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        out = self.out_proj(out)  # Shape: (batch_size, seq_len, embed_dim)
        return out

# Test the model
embed_dim = 64
num_heads = 8
num_groups = 4
model = GroupedQueryAttention(embed_dim, num_heads, num_groups)
x = torch.randn(2, 10, embed_dim)  # Input shape: (batch_size, seq_len, embed_dim)
output = model(x)
print(output.shape)  # Expected output: torch.Size([2, 10, 64])
```
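Where GQA saves over standard multi-head attention (MHA) is in the size of the K/V projections (and, at inference time, the KV cache). The comparison below is just a back-of-the-envelope sketch that reuses the hyperparameters from the test above and counts weights only (biases ignored):

```python
head_dim = embed_dim // num_heads          # 8
group_dim = num_groups * head_dim          # 32

# K and V projection weights in standard MHA vs. GQA
mha_kv_params = 2 * embed_dim * embed_dim  # 2 * 64 * 64 = 8192
gqa_kv_params = 2 * embed_dim * group_dim  # 2 * 64 * 32 = 4096
print(mha_kv_params, gqa_kv_params)        # the K/V projections shrink by num_heads / num_groups
```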

To really understand GQA, I recommend first looking at how MQA (Multi-Query Attention) is implemented; reading the two in that order makes everything much easier to follow.
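For reference, here is what a minimal MQA layer might look like. This is only an illustrative sketch of the idea (all query heads attending to a single shared K/V head, i.e. GQA with `num_groups = 1`), reusing the `torch`, `nn`, and `F` imports from the listing above:

```python
class MultiQueryAttention(nn.Module):
    """Minimal MQA sketch: every query head shares one K/V head."""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.q_proj = nn.Linear(embed_dim, embed_dim)      # full set of query heads
        self.k_proj = nn.Linear(embed_dim, self.head_dim)  # one shared key head
        self.v_proj = nn.Linear(embed_dim, self.head_dim)  # one shared value head
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # K and V keep a single head; the size-1 head dimension broadcasts across all query heads
        k = self.k_proj(x).view(batch_size, seq_len, 1, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, 1, self.head_dim).transpose(1, 2)
        attn = F.softmax(torch.matmul(q, k.transpose(-2, -1)) * self.scale, dim=-1)
        out = torch.matmul(attn, v).transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        return self.out_proj(out)
```

GQA sits between this and full MHA: instead of one shared K/V head it keeps `num_groups` of them, and each group serves `num_heads // num_groups` query heads.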

Once you understand MQA, the implementation of GQA follows almost exactly the same idea; the only extra ingredient is the somewhat uncommon function `tensor.repeat_interleave`. For details on that function, just click through to my related article on it; it is quite easy to understand.
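As a quick illustration of what `repeat_interleave` is doing in the forward pass above (each K/V group is copied so that its copies sit next to each other, lining up with the query heads that belong to that group):

```python
t = torch.tensor([[1, 2],
                  [3, 4]])
print(t.repeat_interleave(2, dim=0))
# tensor([[1, 2],
#         [1, 2],
#         [3, 4],
#         [3, 4]])

# Contrast with tensor.repeat, which tiles the whole tensor instead:
print(t.repeat(2, 1))
# tensor([[1, 2],
#         [3, 4],
#         [1, 2],
#         [3, 4]])
```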
