【知识补充】多头注意力和交叉注意力的区别

多头注意力和交叉注意力的区别

多头注意力和交叉注意力都是在自注意力的基础上发展而来的,它们的主要区别在于注意力矩阵的计算方式不同。

转载☞添加链接描述

多头注意力机制

多头注意力(Multi-Head Attention)是一种基于自注意力机制( self-attentionQ)的改进方法。自注意力是一种能够计算出输入序列中每个位置的权重,因此可以很好地处理序列中长距离依赖关系的问题。但在应用中,可能存在多个不同的关注点,因此就需要多个自注意力机制来处理不同的关注点。多头注意力就是在一个输入序列上使用多个自注意力机制,得到多组注意力结果,然后将这些结果进行拼接和线性投影得到最终输出。

多头注意力的优点是能够处理多个关注点的问题,可以较好地处理复杂语义关系。

多头注意力机制在计算注意力矩阵时,将输入张量 拆分成h 个子张量,每个子张量都是以不同的方式学习到的注意力信息。然后,对于每个子张量,都执行一次自注意力计算,得到一个输出张量 O。最后,将h 个输出张量拼接在一起,得到最终的输出张量 O
在这里插入图片描述

交叉注意力机制

交叉注意力(Cross-Attention)则是在两个不同序列上计算注意力,用于处理两个序列之间的语义关系。例如,在翻译任务中,需要将源语言句子和目标语言句子进行对齐,就需要使用交叉注意力来计算两个句子之间的注意力权重。

在这里插入图片描述

代码

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, in_dim, k_dim, v_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.k_dim = k_dim
        self.v_dim = v_dim
        
        # 定义线性投影层,用于将输入变换到多头注意力空间
        self.proj_q = nn.Linear(in_dim, k_dim * num_heads, bias=False)
        self.proj_k = nn.Linear(in_dim, k_dim * num_heads, bias=False)
        self.proj_v = nn.Linear(in_dim, v_dim * num_heads, bias=False)
		# 定义多头注意力的线性输出层
        self.proj_o = nn.Linear(v_dim * num_heads, in_dim)
        
    def forward(self, x, mask=None):
        batch_size, seq_len, in_dim = x.size()
        # 对输入进行线性投影, 将每个头的查询、键、值进行切分和拼接
        q = self.proj_q(x).view(batch_size, seq_len, self.num_heads, self.k_dim).permute(0, 2, 1, 3)
        k = self.proj_k(x).view(batch_size, seq_len, self.num_heads, self.k_dim).permute(0, 2, 3, 1)
        v = self.proj_v(x).view(batch_size, seq_len, self.num_heads, self.v_dim).permute(0, 2, 1, 3)
        # 计算注意力权重和输出结果
        attn = torch.matmul(q, k) / self.k_dim**0.5   # 注意力得分
        
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)
        
        attn = F.softmax(attn, dim=-1)   # 注意力权重参数
        output = torch.matmul(attn, v).permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, -1)   # 输出结果
        # 对多头注意力输出进行线性变换和输出
        output = self.proj_o(output)
        
        return output

class CrossAttention(nn.Module):
    def __init__(self, in_dim1, in_dim2, k_dim, v_dim, num_heads):
        super(CrossAttention, self).__init__()
        self.num_heads = num_heads
        self.k_dim = k_dim
        self.v_dim = v_dim
        
        self.proj_q1 = nn.Linear(in_dim1, k_dim * num_heads, bias=False)
        self.proj_k2 = nn.Linear(in_dim2, k_dim * num_heads, bias=False)
        self.proj_v2 = nn.Linear(in_dim2, v_dim * num_heads, bias=False)
        self.proj_o = nn.Linear(v_dim * num_heads, in_dim1)
        
    def forward(self, x1, x2, mask=None):
        batch_size, seq_len1, in_dim1 = x1.size()
        seq_len2 = x2.size(1)
        
        q1 = self.proj_q1(x1).view(batch_size, seq_len1, self.num_heads, self.k_dim).permute(0, 2, 1, 3)
        k2 = self.proj_k2(x2).view(batch_size, seq_len2, self.num_heads, self.k_dim).permute(0, 2, 3, 1)
        v2 = self.proj_v2(x2).view(batch_size, seq_len2, self.num_heads, self.v_dim).permute(0, 2, 1, 3)
        
        attn = torch.matmul(q1, k2) / self.k_dim**0.5
        
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)
        
        attn = F.softmax(attn, dim=-1)
        output = torch.matmul(attn, v2).permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len1, -1)
        output = self.proj_o(output)
        
        return output

在这里插入图片描述

多头注意力机制(MHSA)是一种注意力机制,它可以在不同的表示子空间中并行地计算多个注意力分数。这种机制可以帮助模型更好地捕捉输入序列中的不同关系。在图像分割中,MHSA通常被用于编码器的最后一层,以便模型可以同时关注整个图像。而交叉注意力机制则是将注意力机制应用于跳跃连接之后的解码器中,以将高层次语义更丰富的特征图与来自跳跃连接的高分辨率图结合起来,从而提高分割的准确性。 下面是一个简单的例子,展示了如何在PyTorch中实现多头注意力机制交叉注意力机制: ```python import torch import torch.nn as nn # 多头注意力机制 class MultiHeadAttention(nn.Module): def __init__(self, d_model, num_heads): super(MultiHeadAttention, self).__init__() self.num_heads = num_heads self.d_model = d_model assert d_model % num_heads == 0 self.depth = d_model // num_heads self.query = nn.Linear(d_model, d_model) self.key = nn.Linear(d_model, d_model) self.value = nn.Linear(d_model, d_model) self.fc = nn.Linear(d_model, d_model) def split_heads(self, x, batch_size): x = x.view(batch_size, -1, self.num_heads, self.depth) return x.permute(0, 2, 1, 3) def forward(self, query, key, value, mask=None): batch_size = query.size(0) # 线性变换 query = self.query(query) key = self.key(key) value = self.value(value) # 拆分头 query = self.split_heads(query, batch_size) key = self.split_heads(key, batch_size) value = self.split_heads(value, batch_size) # 计算注意力 scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(self.depth).float()) if mask is not None: scores += mask * -1e9 attention = nn.Softmax(dim=-1)(scores) context = torch.matmul(attention, value) # 合并头 context = context.permute(0, 2, 1, 3).contiguous() context = context.view(batch_size, -1, self.d_model) # 线性变换 output = self.fc(context) return output, attention # 交叉注意力机制 class CrossAttention(nn.Module): def __init__(self, d_model): super(CrossAttention, self).__init__() self.query = nn.Linear(d_model, d_model) self.key = nn.Linear(d_model, d_model) self.value = nn.Linear(d_model, d_model) self.fc = nn.Linear(d_model, d_model) def forward(self, query, key, value, mask=None): # 线性变换 query = self.query(query) key = self.key(key) value = self.value(value) # 计算注意力 scores = torch.matmul(query, key.transpose(-1, -2)) / torch.sqrt(torch.tensor(query.size(-1)).float()) if mask is not None: scores += mask * -1e9 attention = nn.Softmax(dim=-1)(scores) context = torch.matmul(attention, value) # 线性变换 output = self.fc(context) return output, attention ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值