PyTorch中torch.nn.MultiheadAttention()的实现(一维情况下)

该代码示例展示了如何在PyTorch中复现MultiheadAttention层,包括设置随机种子、进行QKV的线性变换、计算注意力权重以及输出映射的过程。通过Softmax函数计算注意力权重,并完成张量的乘法操作。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

import torch
import torch.nn as nn
import numpy as np


# TODO MHA
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True


# 设置随机数种子
setup_seed(20)

Q = torch.tensor([[1]], dtype=torch.float32)  # [2, 3, 4]
K = torch.tensor([[3]], dtype=torch.float32)  # [2, 5, 4]
V = torch.tensor([[5]], dtype=torch.float32)  # [2, 5, 4]

multiHead = nn.MultiheadAttention(1, 1)
att_o, att_o_w = multiHead(Q, K, V)

################################

# 复现 Multi-head Attention
w = multiHead.in_proj_weight
b = multiHead.in_proj_bias
w_o = multiHead.out_proj.weight
b_o = multiHead.out_proj.bias

w_q, w_k, w_v = w.chunk(3)
b_q, b_k, b_v = b.chunk(3)

# Q、K、V的映射
q = Q @ w_q + b_q
k = K @ w_k + b_k
v = V @ w_v + b_v
dk = q.shape[-1]
# 注意力权重的计算
softmax_2 = torch.nn.Softmax(dim=-1)
att_o_w2 = softmax_2(q @ k.transpose(-2, -1) / np.sqrt(dk))
# 输出
out = att_o_w * v
# 输出映射
att_o2 = out @ w_o + b_o
print(att_o, att_o_w)
print(att_o2, att_o_w2)
pass

输出结果

tensor([[-0.4038]], grad_fn=<SqueezeBackward1>) tensor([[1.]], grad_fn=<SqueezeBackward1>)
tensor([[-0.4038]], grad_fn=<AddBackward0>) tensor([[1.]], grad_fn=<SoftmaxBackward0>)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值