import torch
# 计算两个向量的注意力分数
def attention_score(query, key):
return torch.matmul(query, key.transpose(-2, -1))
# 计算注意力权重
def attention_weights(query, key, values):
score = attention_score(query, key)
weights = torch.softmax(score, dim=-1)
return torch.matmul(weights, values)
# 构建自注意力层
class SelfAttention(torch.nn.Module):
def __init__(self, input_size, hidden_size):
super(SelfAttention, self).__init__()
self.query = torch.nn.Linear(input_size, hidden_size)
self.key = torch.nn.Linear(input_size, hidden_size)
self.value = torch.nn.Linear(input_size, hidden_size)
self.output = torch.nn.Linear(hidden_size, input_size)
def forward(self, inputs):
query = self.query(inputs)
key = self.key(inputs)
value = self.value(inputs)
attention = attention_weights(query, key, value)
return self.output(attention)
测试用例
# 定义测试用的输入
inputs = torch.randn(4, 5, 8)
# 创建一个自注意力层
attention = SelfAttention(8, 8)
# 计算自注意力层的输出
outputs = attention(inputs)
# 打印输出的形状
print(outputs.shape) # 输出: (4, 5, 8)
在上面的示例代码中,输入张量的形状为 (4, 5, 8),其中 4、5、8 分别表示的含义如下:
4:批大小,即一次计算时处理的样本数量。在这个例子中,我们一次处理了 4 个样本。
5:序列长度,即每个样本中包含的元素数量。在这个例子中,每个样本由 5 个元素组成。
8:元素的维度,即每个元素的特征数量。在这个例子中,每个元素都由 8 个特征构成。
在实际应用中,这些维度的含义可能会有所不同。例如,在自然语言处理中,序列长度通常表示句子中的单词数量,元素维度表示单词的维度(例如词嵌入)。在图像处理中,序列长度可能表示图像中像素的数量,元素维度表示像素的通道数。总之,各个维度的含义取决于模型的具体应用场景。