cross attention输入不同维度的矩阵

在深度学习中,传统cross attention要求输入矩阵维度相同。但实际需求中可能遇到不同维度的数据。通过调整线性层输出,能实现不同维度输入的cross attention。具体做法是将线性层应用于查询向量Q、键向量K和值向量V,确保经过线性层后的维度匹配公式要求。
摘要由CSDN通过智能技术生成

一.问题背景

在学习使用cross attention的时候我查阅了很多资料,发现里面说的都是cross attention的输入需要是相同维度的矩阵,但是我所需要的是可以处理不同维度数据的cross attention。
cross attention

二.cross attention的代码

看了关于cross attention的一些介绍和代码,发现大多都是这样

import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(CrossAttention, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.query = nn.Linear(in_dim, out_dim, bias=False)
        self.key = nn.Linear(in_dim, out_dim, bias=False)
        self.value = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, x, y):
        batch_size = x.shape[0]
        num_queries = x.shape[1]
        num_keys = y.shape[1]
        x = self.query(x)
        y = self.key(y)
        # 计算注意力分数
        attn_scores = torch.matmul(x, y.transpose(-2, -1)) / (self.out_dim ** 0.5)
        attn_weights = F.softmax(attn_scores, dim=-1)
        # 计算加权和
        V = self.value(y)
        output = torch.bmm(attn_weights, V)
        
        return output

这里的x和y所输入的维度需要一致,那么从代码上看好像不太好分析如何进行改变,我们先看看cross attention的公式:

Cross-Attention ( Q , K , V ) = softmax ( ( W Q S 2 ) ( W

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值