import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def SelfAttention(q, k, v, mask, dropout):
    """Scaled dot-product attention over the last two dimensions of q, k, v."""
    d_k = q.size(-1)
    # Scale the dot-product logits by sqrt(d_k) so the softmax does not saturate.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Positions where mask == 0 are blocked with a large negative logit.
        scores = scores.masked_fill(mask == 0, -1e9)
    attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        attn = dropout(attn)
    return torch.matmul(attn, v)
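

# For reference (the standard transformer formulation, assumed here rather than
# stated in the original post):
#
#     Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V
#
# Dividing by sqrt(d_k) keeps the dot products at roughly unit variance, which
# prevents the softmax from saturating when the head dimension d_k is large.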


class MultiHeadAttention(nn.Module):
    def __init__(self, n_head, d_model, dropout=0.1):
        super().__init__()
        assert d_model % n_head == 0, "d_model must be divisible by n_head"
        self.n_head = n_head
        self.d_k = d_model // n_head
        self.d_model = d_model
        # Projection matrices for queries, keys, values, and the output.
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, q, k, v, mask=None):
        n_batch = q.size(0)
        # Project and split into heads: (batch, seq, d_model) -> (batch, n_head, seq, d_k).
        q = self.w_q(q).view(n_batch, -1, self.n_head, self.d_k).transpose(1, 2)
        k = self.w_k(k).view(n_batch, -1, self.n_head, self.d_k).transpose(1, 2)
        v = self.w_v(v).view(n_batch, -1, self.n_head, self.d_k).transpose(1, 2)
        if mask is not None:
            # Add a head dimension so the same mask broadcasts across all heads.
            mask = mask.unsqueeze(1)
        x = SelfAttention(q, k, v, mask, self.dropout)
        # Merge the heads back: (batch, n_head, seq, d_k) -> (batch, seq, d_model).
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.d_model)
        return self.w_o(x)
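

# A minimal sketch (my addition, not part of the original post) of building a
# causal mask compatible with the `mask == 0` convention used in SelfAttention.
# The helper name is illustrative. The returned shape (1, seq_len, seq_len)
# becomes (1, 1, seq_len, seq_len) after `mask.unsqueeze(1)` in
# MultiHeadAttention.forward, so it broadcasts over both batch and head dimensions.
def make_causal_mask(seq_len):
    # Lower-triangular matrix: position i may attend to positions <= i only.
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(0)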


def test_self_attention():
    # Set parameters
    batch_size = 2
    seq_length = 3
    d_k = 4
    # Create inputs
    q = torch.randn(batch_size, seq_length, d_k)
    k = torch.randn(batch_size, seq_length, d_k)
    v = torch.randn(batch_size, seq_length, d_k)
    # Create an (optional) mask
    mask = torch.ones(batch_size, seq_length, seq_length).bool()
    # Create a dropout layer
    dropout = nn.Dropout(p=0.1)
    # Call SelfAttention
    output = SelfAttention(q, k, v, mask, dropout)
    # Check the output shape
    expected_shape = (batch_size, seq_length, d_k)
    assert output.shape == expected_shape, f"Wrong SelfAttention output shape: expected {expected_shape}, got {output.shape}"
    print("SelfAttention test passed! Output shape is correct.")
    print(f"SelfAttention output tensor:\n{output}")
    print(f"SelfAttention output shape: {output.shape}")


def test_multi_head_attention():
    # Set parameters
    batch_size = 32
    seq_length = 10
    d_model = 512
    n_head = 8
    # Create inputs
    q = torch.randn(batch_size, seq_length, d_model)
    k = torch.randn(batch_size, seq_length, d_model)
    v = torch.randn(batch_size, seq_length, d_model)
    # Create an (optional) mask
    mask = torch.ones(batch_size, seq_length, seq_length).bool()
    # Initialize the MultiHeadAttention module
    mha = MultiHeadAttention(n_head, d_model)
    # Forward pass
    output = mha(q, k, v, mask)
    # Check the output shape
    expected_shape = (batch_size, seq_length, d_model)
    assert output.shape == expected_shape, f"Wrong MultiHeadAttention output shape: expected {expected_shape}, got {output.shape}"
    print("MultiHeadAttention test passed! Output shape is correct.")
    print(f"MultiHeadAttention output tensor:\n{output}")
    print(f"MultiHeadAttention output shape: {output.shape}")


if __name__ == "__main__":
    test_self_attention()
    print("\n" + "="*50 + "\n")
    test_multi_head_attention()
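    print("\n" + "="*50 + "\n")
    # Run the optional comparison against the built-in implementation (added above).
    test_against_builtin()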