自注意力机学习

自注意力机制的核心概念

1. Query, Key 和 Value
  • Query(查询向量):可以看作是你当前在关注的输入项。假设你正在阅读一段文字,这就像你当前在读的句子。

  • Key(键向量):表示其他所有输入项的标识或特征。这就像你在书中已经读过的所有句子的摘要或要点。

  • Value(值向量):是与每个Key相关联的具体信息或内容。就像这些句子带来的详细信息。

现实比喻
想象你在图书馆寻找一本特定的书(Query),书架上有很多书,每本书都有一个书名(Key)。根据书名(Key)匹配你的查询(Query),你从合适的书中获取详细内容(Value)。

2. 点积注意力(Dot-Product Attention)

这是计算Query和Key之间相关性的方式。我们通过计算Query和Key的点积来确定它们的关系强度。

比喻
就像在图书馆,你有一本书的部分标题(Query),你对比书架上所有书的书名(Key),看哪个书名最接近你的标题,然后选出最相关的书(Value)。

3. 缩放(Scaling)

为了防止Query和Key之间的点积结果太大导致数值不稳定,我们将结果除以一个常数——通常是Key向量的维度的平方根。这使得计算更加稳定。

比喻
假设你在测试你的记忆力,如果你直接用高分数衡量,可能会出现极端值。所以你需要调整分数范围,使得评估更合理和稳定。

4. Softmax 归一化

Softmax函数将一组数值转换为概率分布,使得它们的总和为1。这意味着每个单词的注意力权重表示它对当前处理单词的重要性。

比喻
就像你在评分不同的书,Softmax就像把所有的分数转换成百分比,这样你可以看到每本书相对于其他书的重要性。

自注意力机制的工作流程

让我们更详细地看看自注意力机制是如何一步一步工作的:

  1. 生成 Query, Key 和 Value 向量

    我们首先通过线性变换将输入序列的每个单词转换成三个不同的向量:Query, Key 和 Value。

    query = W_q * input
    key = W_k * input
    value = W_v * input
    

    比喻:这是把每个单词变成三个不同的代表,就像给每个单词生成了三个不同的标签,用于不同的目的(查询、匹配和提供信息)。

  2. 计算注意力权重

    通过计算Query和Key的点积,我们得到它们之间的相关性得分。然后,我们将这些得分除以 d k \sqrt{d_k} dk 进行缩放,最后应用Softmax函数来得到权重。

    # 计算点积
    scores = query.dot(key.T) / sqrt(d_k)
    # 使用Softmax函数归一化
    attention_weights = softmax(scores)
    

    比喻:这就像你比较当前正在读的句子(Query)和你已经读过的所有句子(Key),然后根据它们的相似程度打分。接着,你将这些分数标准化,使它们总和为1,表示每个句子的重要性百分比。

  3. 加权求和 Value 向量

    我们将Value向量按照注意力权重进行加权求和,这样每个Value对最终输出的贡献由它的重要性决定。

    # 计算加权的Value
    output = sum(attention_weights * value)
    

    比喻:就像你根据每本书的重要性百分比(注意力权重),从每本书中提取一定量的信息(Value),最终形成你对整个图书馆信息的理解。

示例和实际应用

假设你在处理一句话“我喜欢吃苹果,因为苹果很甜”:

  1. Query, Key, Value

    • Query:当前处理的词是“苹果”。
    • Key:句子中的所有单词的表示,如“我”,“喜欢”,“吃”,“苹果”,“因为”,“很”,“甜”。
    • Value:这些单词的具体信息,比如它们的词义或上下文信息。
  2. 点积注意力

    • 你在评估“苹果”和句子中其他词的关系,比如“苹果”与“甜”的关系就很重要,而与“我”关系可能不大。
  3. Softmax 归一化

    • 将关系得分转化为一个概率分布,表示每个单词对当前词“苹果”的重要性。
  4. 加权求和

    • 最后,根据重要性权重,从每个单词中提取信息,生成“苹果”的最终表示,这样“苹果”就包含了它和“甜”的关系。

自注意力机制代码示例

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, embed_size, bias=False)
        self.keys = nn.Linear(self.head_dim, embed_size, bias=False)
        self.queries = nn.Linear(self.head_dim, embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # 1. 生成 Query, Key 和 Value 向量
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        # 2. 计算注意力权重
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        # 3. 加权求和 Value 向量
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.embed_size
        )

        out = self.fc_out(out)
        return out

关键概念总结

  1. 自注意力机制:允许模型在处理一个输入时,同时关注到整个输入序列中的所有其他输入。提高了捕捉长距离依赖关系的能力。

  2. Query, Key 和 Value:分别代表当前处理的焦点、其他输入的标识和它们携带的信息。

  3. 点积注意力:通过计算Query和Key的相似性来确定它们之间的关系强度。

  4. 缩放:对点积结果进行调整,防止数值过大导致计算不稳定。

  5. Softmax 归一化:将相似性得分转化为概率分布,表示每个输入的重要性。

通过这些步骤,自注意力机制能够帮助模型在处理每一个输入时同时考虑整个序列,从而更好地理解上下文和词语之间的关系。

  • 14
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值