Python 实现神经网络模型的注意力机制
attention
是指神经网络模型的注意力机制,在 Python 中可以使用 PyTorch 框架来实现注意力机制attention
,实现注意力机制attention
前请确保已经按照了对应的torch
库。
如下是一个使用torch
库来实现注意力机制attention
的 Python 示例代码:
import torch
# 继承自 nn.Module 基类
class MultiHeadAttention(torch.nn.Module):
def __init__(self, n_heads, d_model, dropout=0.1):
super(MultiHeadAttention, self).__init__()
# 多头注意力头数
self.n_heads = n_heads
# 输入向量维度
self.d_model = d_model
# 每个头的维度
self.d_k = d_model // n_heads
# dropout 概率
self.dropout = torch.nn.Dropout(p=dropout)
# 初始化 Query、Key、Value 的权重矩阵
self.W_q = torch.nn.Linear(d_model, n_heads * self.d_k)
self.W_k = torch.nn.Linear(d_model, n_heads * self.d_k)
self.W_v = torch.nn.Linear(d_model, n_heads * self.d_k)
# 初始化输出的权重矩阵即输出向量的权重矩阵
self.W_o = torch.nn.Linear(n_heads * self.d_k, d_model)
def forward(self, x, mask=None):
# 输入 x 的维度为 [batch_size, seq_len, d_model]
batch_size, seq_len, d_model = x.size()
# 通过权重矩阵计算 Q、K、V
Q = self.W_q(x).view(batch_size, seq_len, self.n_heads, self.d_k)
K = self.W_k(x).view(batch_size, seq_len, self.n_heads, self.d_k)
V = self.W_v(x).view(batch_size, seq_len, self.n_heads, self.d_k)
# 交换维度以便于计算注意力权重
Q = Q.permute(0, 2, 1, 3).contiguous().view(batch_size * self.n_heads, seq_len, self.d_k)
K = K.permute(0, 2, 1, 3).contiguous().view(batch_size * self.n_heads, seq_len, self.d_k)
V = V.permute(0, 2, 1, 3).contiguous().view(batch_size * self.n_heads, seq_len, self.d_k)
# 计算注意力权重
scores = torch.bmm(Q, K.transpose(1, 2)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float))
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = torch.nn.Softmax(dim=-1)(scores)
attn_weights = self.dropout(attn_weights)
# 计算输出向量
attn_output = torch.bmm(attn_weights, V)
attn_output = attn_output.view(batch_size, self.n_heads, seq_len, self.d_k)
attn_output = attn_output.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len,
self.n_heads * self.d_k)
output = self.W_o(attn_output)
return output
# 定义输入向量
x = torch.randn(2, 10, 128)
# 定义注意力模块
attn = MultiHeadAttention(n_heads=8, d_model=128)
# 进行前向传播计算
output = attn(x)
# 打印输出向量的形状 torch.Size([2, 10, 128])
print(output.shape)
上述代码中定义了一个名为MultiHeadAttention
的类,它继承自torch
库的Module
类,MultiHeadAttention
类接收注意力头数n_heads
和向量d_model
作为参数来初始化权重矩阵,在forward
方法中,根据权重矩阵获取节点Q
、K
和V
,然后通过permute
方法交换节点维度计算注意力权重后调用Softmax
方法和dropout
方法来计算并输出注意力权重即张量,最终通过shape
方法来输出向量的形状。