概念
注意力机制(Attention Mechanism)是一种模拟人类将“注意力”集中在必要的信息上,它使模型能够专注于输入的某些部分而忽略其他部分。该机制在机器翻译、图像处理和自然语言处理等领域中发挥了重要作用。
在传统的神经网络中,所有的输入数据对模型的重要性是一样的。而在注意力机制中,模型会为每个输入分配一个权重,表示该输入的重要性。这个过程可以帮助模型更好地捕捉长距离依赖关系,尤其是在序列数据处理中。
计算规则
它需要三个指定的输入Q(query), K(key), V(value), 然后通过计算公式得到注意力的结果, 这个结果代表query在key和value作用下的注意力表示。
常见的注意力机制:自注意力机制、多头注意力机制
常见的计算公式:
Attention(Q, K, V) = Softmax(Linear([Q, K])) · V Attention(Q, K, V) = Softmax(sum(tanh(Linear([Q, K])))) · V Attention(Q, K, V) = Softmax((Q · K^T)/d) · V
作用
注意力机制大多数应用于编码器和解码器模型
在编码器端,相当于特征提取
在解码器端,提升信息存储的效果
实现
"""
使用第一个公式
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class Attention(nn.Module):
def __init__(self, query_dim, key_dim, value_dim1, value_dim2, output_dim):
"""
初始化函数:
- query_dim: query 向量的最后一维大小
- key_dim: key 向量的最后一维大小
- value_dim1: value 向量的倒数第二维大小
- value_dim2: value 向量的倒数第一维大小
- output_dim: 输出向量的最后一维大小
"""
super(Attention, self).__init__()
self.query_dim = query_dim
self.key_dim = key_dim
self.value_dim1 = value_dim1
self.value_dim2 = value_dim2
self.output_dim = output_dim
# 初始化用于计算注意力权重的线性层
self.attention_layer = nn.Linear(self.query_dim + self.key_dim, value_dim1)
# 初始化用于计算最终输出的线性层
self.output_layer = nn.Linear(self.query_dim + value_dim2, output_dim)
def forward(self, query, key, value):
"""
前向传播函数:
输入参数:
- query: 形状为 (batch_size, query_dim) 的张量
- key: 形状为 (batch_size, key_dim) 的张量
- value: 形状为 (batch_size, value_dim1, value_dim2) 的张量
返回:
- output: 形状为 (1, batch_size, output_dim) 的张量
- attention_weights: 形状为 (batch_size, value_dim1) 的注意力权重张量
"""
# 1. 拼接 Q 和 K, 计算注意力权重
attention_weights = F.softmax(
self.attention_layer(torch.cat((query[0], key[0]), dim=1)), dim=1)
# 2. 使用注意力权重与 V 进行矩阵乘法,并与 Q 拼接
output = torch.cat(
(query[0], torch.bmm(attention_weights.unsqueeze(0), value)[0]), dim=1)
# 3. 通过线性层计算最终输出,并扩展维度
output = self.output_layer(output).unsqueeze(0)
return output, attention_weights
if __name__ == "__main__":
# 定义输入的维度
query_dim = 64
key_dim = 64
value_dim1 = 10
value_dim2 = 32
output_dim = 128
# 初始化 Attention 模块
attention = Attention(query_dim, key_dim, value_dim1, value_dim2, output_dim)
# 创建模拟的输入数据
batch_size = 1
# 随机生成 query, key 和 value 张量
query = torch.randn(batch_size, query_dim).unsqueeze(0) # 形状: (1, batch_size, query_dim)
key = torch.randn(batch_size, key_dim).unsqueeze(0) # 形状: (1, batch_size, key_dim)
value = torch.randn(batch_size, value_dim1, value_dim2) # 形状: (batch_size, value_dim1, value_dim2)
# 调用 Attention 模块进行前向传播
output, attention_weights = attention(query, key, value)
# 输出结果
print("Output shape:", output.shape) # 输出张量的形状
# print("Output:", output) # 输出张量的值
print("Attention Weights shape:", attention_weights.shape) # 注意力权重张量的形状
# print("Attention Weights:", attention_weights) # 注意力权重张量的值