大模型相关代码2 -- 多头注意力机制【手撕基础模型】

注意力机制

  • 看见网上有要手撕注意力的面经,自己开始写,以为很简单,实际上自己菜的要死,遂写本博客。

公式

A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K t d ) V Attention(Q,K,V) = softmax(\frac{QK^t}{\sqrt d})V Attention(Q,K,V)=softmax(d QKt)V
H e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) Head_i = Attention(QW_i^Q,KW_i^K,VW_i^V) Headi=Attention(QWiQ,KWiK,VWiV)
M H A ( Q , K , V ) = C o n c a t ( H e a d 1 , … … , H e a d h ) W O MHA(Q,K,V) = Concat(Head_1,……,Head_h)W^O MHA(Q,K,V)=Concat(Head1,……,Headh)WO
O为最终输出变幻矩阵

代码

class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        self.d_model = d_model  # 模型的维度
        self.d_k = d_model // heads  # 每个头的维度
        self.h = heads  # 头的数量

        # 以下三个是线性层,用于处理Q(Query),K(Key),V(Value)
        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)  # Dropout层
        self.out = nn.Linear(d_model, d_model)  # 输出层

    def attention(self, q, k, v, d_k, mask=None, dropout=None):
        # torch.matmul是矩阵乘法,用于计算query和key的相似度
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            mask = mask.unsqueeze(1)  # 在第一个维度增加维度
            scores = scores.masked_fill(mask == 0, -1e9)  # 使用mask将不需要关注的位置设置为一个非常小的数

        # 对最后一个维度进行softmax运算,得到权重
        scores = F.softmax(scores, dim=-1)

        if dropout is not None:
            scores = dropout(scores)  # 应用dropout

        output = torch.matmul(scores, v)  # 将权重应用到value上
        return output

    def forward(self, q, k, v, mask=None):
        bs = q.size(0)  # 获取batch_size

        # 将Q, K, V通过线性层处理,然后分割成多个头
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k) # batchsize * sentence_length * head * head_dim
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)

        # batchsize * head * sentence_length * head_dim
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        # 调用attention函数计算输出
        scores = attention(q, k, v, self.d_k, mask, self.dropout)
        # 重新调整张量的形状,并通过最后一个线性层
        concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        '''
        在PyTorch中,张量的存储方式是连续的(contiguous),也就是说张量的元素在内存中是按照一定顺序存储的。但是,在进行某些操作后,例如通过索引、转置、视图变换等操作,张量可能会变得不再连续。
        '''
        output = self.out(concat)  # 最终输出
        return output
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值