20240223通过简单的 自注意力机制代码 理解 自注意力机制

想象你在读一本故事书,故事里提到了很多人物。当你读到某个人物的名字时,为了更好地理解这个人物在故事中的角色和重要性,你可能会回想起这个人物之前出现的情景和与其他人物的互动。自注意力机制在做的,其实就是类似的事情,但它是机器学习模型在做这个过程。

然后下面是一个自注意力机制代码,它的作用是识别sentence中每一个词,并输出这个词对句子中所有其他词的重要性和关联度。

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, embed_size):
        super(SelfAttention, self).__init__()
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)
    
    def forward(self, x):
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)
        
        # 计算注意力得分
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(embed_size, dtype=torch.float32))
        attention_weights = F.softmax(attention_scores, dim=-1)
        
        # 加权求和
        output = torch.matmul(attention_weights, V)
        return output, attention_weights

# 假设的嵌入大小
embed_size = 4
# 初始化嵌入层和自注意力模块
embedding = nn.Embedding(10, embed_size)  # 假设词汇表大小为10,仅为示例
self_attention = SelfAttention(embed_size)

# 示例句子,用数字代替词汇
sentence = torch.tensor([1, 2, 3,1,2,3,4,5], dtype=torch.long)  # 假设句子由四个词组成,每个词用一个索引表示
sentence_embeddings = embedding(sentence)

# 应用自注意力
output, attn_weights = self_attention(sentence_embeddings.unsqueeze(0))  # 增加批次维度

# 打印注意力权重
print("Attention weights:\n", attn_weights.squeeze().detach().numpy())

注意这里的sentence用数字代替了词语,可以想像为1是I,2是love,3是you,4是very,5是much

以下面的输出结果

Attention weights:
 [[0.17980367 0.11659602 0.07916874 0.17980367 0.11659602 0.07916874
  0.1249155  0.12394763]
 [0.18425122 0.11723904 0.06853089 0.18425122 0.11723904 0.06853089
  0.14476162 0.1151961 ]
 [0.15084903 0.1105122  0.12573701 0.15084903 0.1105122  0.12573701
  0.06491332 0.16089015]
 [0.17980367 0.11659602 0.07916874 0.17980367 0.11659602 0.07916874
  0.1249155  0.12394763]
 [0.18425122 0.11723904 0.06853089 0.18425122 0.11723904 0.06853089
  0.14476162 0.1151961 ]
 [0.15084903 0.1105122  0.12573701 0.15084903 0.1105122  0.12573701
  0.06491332 0.16089015]
 [0.09749766 0.12622462 0.1572222  0.09749766 0.12622462 0.1572222
  0.09982932 0.1382817 ]
 [0.22876804 0.10456762 0.04837963 0.22876804 0.10456762 0.04837963
  0.12278523 0.11378422]]

第1行代表词语I对每一个其他词的重要性,比如对I的重要性是0.1798,对love的重要性是0.1165,对you的重要性是0.07916,以此类推

第2行代表词语love对每一个其他词的重要性

以此类推。

当然,这是非常简化的自注意力机制代码。比如注意到重复的数字重要性是相同的,但是实际上相同的词在前后的重要性通常是不同的。

我问了一下GPT,看来真的要用词语代替掉数字还要处理很多问题。

 另外,我听说通常会在每一个词的编码向量后面加上一个位置信息(它在句子里的位置)来标识不同位置的相同单词,这样可以给他们不同的重要性。

ViT(Vision Transformers)里也会给不同Patch对应的向量增加位置信息,这么说来我感觉ViT中的Patch之于图片可能就相当于单词之于句子。

  • 9
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值