注意力机制

概念

        注意力机制(Attention Mechanism)是一种模拟人类将“注意力”集中在必要的信息上,它使模型能够专注于输入的某些部分而忽略其他部分。该机制在机器翻译、图像处理和自然语言处理等领域中发挥了重要作用。

        在传统的神经网络中,所有的输入数据对模型的重要性是一样的。而在注意力机制中,模型会为每个输入分配一个权重,表示该输入的重要性。这个过程可以帮助模型更好地捕捉长距离依赖关系,尤其是在序列数据处理中。

计算规则

        它需要三个指定的输入Q(query), K(key), V(value), 然后通过计算公式得到注意力的结果, 这个结果代表query在key和value作用下的注意力表示。

        常见的注意力机制:自注意力机制、多头注意力机制

        常见的计算公式:

Attention(Q, K, V) = Softmax(Linear([Q, K])) · V
Attention(Q, K, V) = Softmax(sum(tanh(Linear([Q, K])))) · V
Attention(Q, K, V) = Softmax((Q · K^T)/d) · V

作用

        注意力机制大多数应用于编码器和解码器模型

        在编码器端,相当于特征提取

        在解码器端,提升信息存储的效果

实现

"""
使用第一个公式
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    def __init__(self, query_dim, key_dim, value_dim1, value_dim2, output_dim):
        """
        初始化函数:
        - query_dim: query 向量的最后一维大小
        - key_dim: key 向量的最后一维大小
        - value_dim1: value 向量的倒数第二维大小
        - value_dim2: value 向量的倒数第一维大小
        - output_dim: 输出向量的最后一维大小
        """
        super(Attention, self).__init__()
        self.query_dim = query_dim
        self.key_dim = key_dim
        self.value_dim1 = value_dim1
        self.value_dim2 = value_dim2
        self.output_dim = output_dim

        # 初始化用于计算注意力权重的线性层
        self.attention_layer = nn.Linear(self.query_dim + self.key_dim, value_dim1)

        # 初始化用于计算最终输出的线性层
        self.output_layer = nn.Linear(self.query_dim + value_dim2, output_dim)

    def forward(self, query, key, value):
        """
        前向传播函数:
        输入参数:
        - query: 形状为 (batch_size, query_dim) 的张量
        - key: 形状为 (batch_size, key_dim) 的张量
        - value: 形状为 (batch_size, value_dim1, value_dim2) 的张量

        返回:
        - output: 形状为 (1, batch_size, output_dim) 的张量
        - attention_weights: 形状为 (batch_size, value_dim1) 的注意力权重张量
        """
        # 1. 拼接 Q 和 K, 计算注意力权重
        attention_weights = F.softmax(
            self.attention_layer(torch.cat((query[0], key[0]), dim=1)), dim=1)

        # 2. 使用注意力权重与 V 进行矩阵乘法,并与 Q 拼接
        output = torch.cat(
            (query[0], torch.bmm(attention_weights.unsqueeze(0), value)[0]), dim=1)

        # 3. 通过线性层计算最终输出,并扩展维度
        output = self.output_layer(output).unsqueeze(0)

        return output, attention_weights


if __name__ == "__main__":
    # 定义输入的维度
    query_dim = 64
    key_dim = 64
    value_dim1 = 10
    value_dim2 = 32
    output_dim = 128

    # 初始化 Attention 模块
    attention = Attention(query_dim, key_dim, value_dim1, value_dim2, output_dim)

    # 创建模拟的输入数据
    batch_size = 1

    # 随机生成 query, key 和 value 张量
    query = torch.randn(batch_size, query_dim).unsqueeze(0)  # 形状: (1, batch_size, query_dim)
    key = torch.randn(batch_size, key_dim).unsqueeze(0)  # 形状: (1, batch_size, key_dim)
    value = torch.randn(batch_size, value_dim1, value_dim2)  # 形状: (batch_size, value_dim1, value_dim2)

    # 调用 Attention 模块进行前向传播
    output, attention_weights = attention(query, key, value)

    # 输出结果
    print("Output shape:", output.shape)  # 输出张量的形状
    # print("Output:", output)  # 输出张量的值
    print("Attention Weights shape:", attention_weights.shape)  # 注意力权重张量的形状
    # print("Attention Weights:", attention_weights)  # 注意力权重张量的值
  • 8
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值