一.问题背景
在学习使用cross attention的时候我查阅了很多资料,发现里面说的都是cross attention的输入需要是相同维度的矩阵,但是我所需要的是可以处理不同维度数据的cross attention。
cross attention
二.cross attention的代码
看了关于cross attention的一些介绍和代码,发现大多都是这样
import torch
import torch.nn as nn
import torch.nn.functional as F
class CrossAttention(nn.Module):
def __init__(self, in_dim, out_dim):
super(CrossAttention, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.query = nn.Linear(in_dim, out_dim, bias=False)
self.key = nn.Linear(in_dim, out_dim, bias=False)
self.value = nn.Linear(in_dim, out_dim, bias=False)
def forward(self, x, y):
batch_size = x.shape[0]
num_queries = x.shape[1]
num_keys = y.shape[1]
x = self.query(x)
y = self.key(y)
# 计算注意力分数
attn_scores = torch.matmul(x, y.transpose(-2, -1)) / (self.out_dim ** 0.5)
attn_weights = F.softmax(attn_scores, dim=-1)
# 计算加权和
V = self.value(y)
output = torch.bmm(attn_weights, V)
return output
这里的x和y所输入的维度需要一致,那么从代码上看好像不太好分析如何进行改变,我们先看看cross attention的公式:
Cross-Attention ( Q , K , V ) = softmax ( ( W Q S 2 ) ( W