想象你在读一本故事书,故事里提到了很多人物。当你读到某个人物的名字时,为了更好地理解这个人物在故事中的角色和重要性,你可能会回想起这个人物之前出现的情景和与其他人物的互动。自注意力机制在做的,其实就是类似的事情,但它是机器学习模型在做这个过程。
然后下面是一个自注意力机制代码,它的作用是识别sentence中每一个词,并输出这个词对句子中所有其他词的重要性和关联度。
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
def __init__(self, embed_size):
super(SelfAttention, self).__init__()
self.query = nn.Linear(embed_size, embed_size)
self.key = nn.Linear(embed_size, embed_size)
self.value = nn.Linear(embed_size, embed_size)
def forward(self, x):
Q = self.query(x)
K = self.key(x)
V = self.value(x)
# 计算注意力得分
attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(embed_size, dtype=torch.float32))
attention_weights = F.softmax(attention_scores, dim=-1)
# 加权求和
output = torch.matmul(attention_weights, V)
return output, attention_weights
# 假设的嵌入大小
embed_size = 4
# 初始化嵌入层和自注意力模块
embedding = nn.Embedding(10, embed_size) # 假设词汇表大小为10,仅为示例
self_attention = SelfAttention(embed_size)
# 示例句子,用数字代替词汇
sentence = torch.tensor([1, 2, 3,1,2,3,4,5], dtype=torch.long) # 假设句子由四个词组成,每个词用一个索引表示
sentence_embeddings = embedding(sentence)
# 应用自注意力
output, attn_weights = self_attention(sentence_embeddings.unsqueeze(0)) # 增加批次维度
# 打印注意力权重
print("Attention weights:\n", attn_weights.squeeze().detach().numpy())
注意这里的sentence用数字代替了词语,可以想像为1是I,2是love,3是you,4是very,5是much
以下面的输出结果
Attention weights:
[[0.17980367 0.11659602 0.07916874 0.17980367 0.11659602 0.07916874
0.1249155 0.12394763]
[0.18425122 0.11723904 0.06853089 0.18425122 0.11723904 0.06853089
0.14476162 0.1151961 ]
[0.15084903 0.1105122 0.12573701 0.15084903 0.1105122 0.12573701
0.06491332 0.16089015]
[0.17980367 0.11659602 0.07916874 0.17980367 0.11659602 0.07916874
0.1249155 0.12394763]
[0.18425122 0.11723904 0.06853089 0.18425122 0.11723904 0.06853089
0.14476162 0.1151961 ]
[0.15084903 0.1105122 0.12573701 0.15084903 0.1105122 0.12573701
0.06491332 0.16089015]
[0.09749766 0.12622462 0.1572222 0.09749766 0.12622462 0.1572222
0.09982932 0.1382817 ]
[0.22876804 0.10456762 0.04837963 0.22876804 0.10456762 0.04837963
0.12278523 0.11378422]]
第1行代表词语I对每一个其他词的重要性,比如对I的重要性是0.1798,对love的重要性是0.1165,对you的重要性是0.07916,以此类推
第2行代表词语love对每一个其他词的重要性
以此类推。
当然,这是非常简化的自注意力机制代码。比如注意到重复的数字重要性是相同的,但是实际上相同的词在前后的重要性通常是不同的。
我问了一下GPT,看来真的要用词语代替掉数字还要处理很多问题。
另外,我听说通常会在每一个词的编码向量后面加上一个位置信息(它在句子里的位置)来标识不同位置的相同单词,这样可以给他们不同的重要性。
ViT(Vision Transformers)里也会给不同Patch对应的向量增加位置信息,这么说来我感觉ViT中的Patch之于图片可能就相当于单词之于句子。