In fact, this is the same as ordinary attention: it just computes a weighted combination over the different frames, so the frame axis here plays the role of the sequence length L in the original attention.
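For reference, the computation below is the standard scaled dot-product attention, with the number of frames f taking the place of the sequence length L (per head, with d_k the per-head dimension):

$$
\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V,\qquad Q,K,V\in\mathbb{R}^{f\times d_k}
$$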
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class cross_frame_attn(nn.Module):
    def __init__(self, embed_dim, n_heads, k_size, v_size):
        super().__init__()
        assert embed_dim % n_heads == 0, "embed_dim must be divisible by n_heads"
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.n_head_dim = self.embed_dim // self.n_heads
        # Input projections. Note that all three are applied to the same x in
        # forward(), so the last dimension of x must equal both k_size and v_size.
        self.to_q = nn.Linear(k_size, self.embed_dim, bias=False)
        self.to_k = nn.Linear(k_size, self.embed_dim, bias=False)
        self.to_v = nn.Linear(v_size, self.embed_dim, bias=False)
        # Extra per-path projections inside the embedding space.
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        # Output projection back to the value feature size.
        self.to_out = nn.Linear(self.embed_dim, v_size, bias=False)

    def forward(self, x):
        # x: [B, num_frames, feature_size]
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)
        B, f = v.shape[:2]
        # Split into heads: [B, f, embed_dim] -> [B, n_heads, f, head_dim]
        q = q.reshape(B, f, self.n_heads, -1).transpose(1, 2)
        k = k.reshape(B, f, self.n_heads, -1).transpose(1, 2)
        v = v.reshape(B, f, self.n_heads, -1).transpose(1, 2)
        # Scaled dot-product attention over the frame axis:
        # score[b, h, i, j] is how much frame i attends to frame j.
        score = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(v.shape[-1])
        score = F.softmax(score, dim=-1)
        out = torch.matmul(score, v)
        # Merge heads back: [B, n_heads, f, head_dim] -> [B, f, embed_dim],
        # then project to [B, f, v_size].
        out = self.to_out(out.transpose(1, 2).reshape(B, f, -1))
        return out
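A minimal usage sketch, continuing from the definition above. The sizes here are illustrative assumptions, not from the original post; k_size and v_size are set equal because both projections are applied to the same input x:

# Hypothetical sizes for illustration only.
attn = cross_frame_attn(embed_dim=128, n_heads=8, k_size=64, v_size=64)
x = torch.randn(2, 8, 64)   # [B=2, num_frames=8, feature_size=64]
out = attn(x)
print(out.shape)            # torch.Size([2, 8, 64]): to_out restores v_size

The attention weights have shape [B, n_heads, f, f], i.e. one f-by-f weighting matrix per head, which is exactly the "weighting over frames" described at the start.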