注意力机制:多头注意力(MultiHeadAttention+缩放点积注意力(scaled dot-product attention)代码详细实现+手动绘制的MultiHeadAttention网络

文章详细介绍了如何使用PyTorch实现多头注意力机制,包括transpose_qkv函数用于转换输入张量以便于多头并行计算,以及DotProductAttention和MultiHeadAttention类,展示了缩放点积注意力的结构和计算过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一, 代码实例:

import torch
import math
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt
# 定义transpose_qkv函数
def transpose_qkv(X, num_heads):
    """为了多注意力头的并行计算而变换形状"""
    # 输入X的形状:(batch_size, 查询或者“键-值”对的个数,num_hiddens)
    # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads)
    print('transpose_qkv:')
    print(X.shape)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, 20)
    print(X.shape)

    # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数, num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)
    print(X.shape)

    # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数, num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])

def transpose_output(X, num_heads):
    """逆转transpose_qkv函数的操作"""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

def masked_softmax(X, valid_lens):
    """通过在最后一个轴上掩蔽元素来执行softmax操作"""
    # X:3D张量,valid_lens:1D或2D张量
    print('masked_softmax:', file=log)

    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
            print(valid_lens.shape, file=log)
            print(valid_lens, file=log)
        else:
            valid_lens = valid_lens.reshape(-1)
        # 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值