cross attention交叉熵注意力机制

        交叉注意力(Cross-Attention)则是在两个不同序列上计算注意力,用于处理两个序列之间的语义关系。在两个不同的输入序列之间计算关联度和加权求和的机制。具体来说,给定两个输入序列,cross attention机制将一个序列中的每个元素与另一个序列中的所有元素计算关联度,并根据关联度对两个序列中的每个元素进行加权求和。这样的机制使模型能够建立不同序列之间的关联关系,并将两个序列的信息融合起来。例如,在翻译任务中,需要将源语言句子和目标语言句子进行对齐,就需要使用交叉注意力来计算两个句子之间的注意力权重。

        交叉注意力机制是一种特殊形式的多头注意力,它将输入张量拆分成两个部分 X1\epsilon R^{n*d1}  和 X2\epsilon R^{n*d2},然后将其中一个部分作为查询集合,另一个部分作为键值集合。它的输出是一个大小为n*d2 的张量,对于每个行向量,都给出了它对于所有行向量的注意力权重。

Q=X_{1} W^{Q} 和 K=V=X_{2} W^{K},则交叉注意力的计算如下:

\operatorname{CrossAttention}\left(X_{1}, X_{2}\right)=\operatorname{Softmax}\left(\frac{Q K^{T}}{\sqrt{d_{2}}}\right) V

import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_heads):
        super(CrossAttention, self).__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        self.query_proj = nn.Linear(embed_dim, hidden_dim * num_heads)
        self.key_proj = nn.Linear(embed_dim, hidden_dim * num_heads)
        self.value_proj = nn.Linear(embed_dim, hidden_dim * num_heads)

        self.out_proj = nn.Linear(hidden_dim * num_heads, embed_dim)

    def forward(self, query, context):
        """
        query: (batch_size, query_len, embed_dim)
        context: (batch_size, context_len, embed_dim)
        """
        batch_size, query_len, _ = query.size()
        context_len = context.size(1)

        # Project input embeddings
        query_proj = self.query_proj(query).view(batch_size, query_len, self.num_heads, self.hidden_dim)
        key_proj = self.key_proj(context).view(batch_size, context_len, self.num_heads, self.hidden_dim)
        value_proj = self.value_proj(context).view(batch_size, context_len, self.num_heads, self.hidden_dim)

        # Transpose to get dimensions (batch_size, num_heads, len, hidden_dim)
        query_proj = query_proj.permute(0, 2, 1, 3)
        key_proj = key_proj.permute(0, 2, 1, 3)
        value_proj = value_proj.permute(0, 2, 1, 3)

        # Compute attention scores
        scores = torch.matmul(query_proj, key_proj.transpose(-2, -1)) / (self.hidden_dim ** 0.5)
        attn_weights = F.softmax(scores, dim=-1)

        # Compute weighted context
        context = torch.matmul(attn_weights, value_proj)

        # Concatenate heads and project output
        context = context.permute(0, 2, 1, 3).contiguous().view(batch_size, query_len, -1)
        output = self.out_proj(context)

        return output, attn_weights

# Example usage:
embed_dim = 512
hidden_dim = 64
num_heads = 8

cross_attention = CrossAttention(embed_dim, hidden_dim, num_heads)

# Dummy data
batch_size = 2
query_len = 10
context_len = 20

query = torch.randn(batch_size, query_len, embed_dim)
context = torch.randn(batch_size, context_len, embed_dim)

output, attn_weights = cross_attention(query, context)
print(output.size())  # Should be (batch_size, query_len, embed_dim)
print(attn_weights.size())  # Should be (batch_size, num_heads, query_len, context_len)
  1. 类定义CrossAttention 类继承自 nn.Module,包含初始化函数 __init__ 和前向传播函数 forward
  2. 初始化
    • 定义了一些线性变换层:query_proj, key_proj, 和 value_proj,这些层将嵌入向量转换为多头注意力机制所需的维度。
    • 最终的输出通过 out_proj 再投影回原始的嵌入维度。
  3. 前向传播
    • 输入的 querycontext 分别通过线性变换层,并重新整形以适应多头注意力机制。
    • 计算注意力分数,并通过 softmax 得到注意力权重。
    • 利用注意力权重加权上下文向量,得到新的上下文表示。
    • 最后将多头的结果合并,并通过输出投影层得到最终的输出。
  • 8
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值