einops与torch.einsum

一句话总结:einops负责变形操作,einsum负责乘法与加法操作

参考https://zhuanlan.zhihu.com/p/101157166
https://zhuanlan.zhihu.com/p/372692913

einops


from einops import rearrange,repeat,reduce

import torch

rearrange

做维度操作,比如拉平,拼接,调换维度顺序,分patch等
eg.
output = rearrange(a, 'c (r p) w -> c r p w', p=3)把本来的(c,h,w)的h拆成(r,p),指定p的值,自动计算r的值
rearrange(images, 'b h w c -> (b h) w c')把b和h合并在一起,这是什么意思?
h是图像的高,b是图像的数量,合并之后,相当于图像的高度扩大了b倍,实际上就是把b张图像垂直拼接在一起
更复杂的:
image = rearrange(images, '(b1 b2) h w c -> (b1 h) (b2 w) c', b1=2)
(1)(b,h,w,c)先被拆成了(b1,b2,h,w,c),也就是说b张图片变成了b1组,每组b2张图片
(2)b1和h合并,b2和w合并,根据上一个例子,我们可以知道,就是把b1张图片垂直拼接,b2张图片水平拼接,最后的形状是(b1 h) (b2 w) c,从后往前看,可以理解为每组b2张图片先水平拼接,再把所有组拼接好的图片垂直拼接

分patch操作:
image = rearrange(images, 'b (h p_h) (w p_w) c -> b (h w) p_h p_w c', p_h = 150, p_w = 200) # 一张图划分为4个patch

reduce

一般用来做pooling操作
reduce(images, 'b h w c -> b h w', reduction='mean')AVG pooling,最后c没有了,可以知道是对通道求均值,每个位置的所有通道求均值
reduce(images, '(b1 b2) h w c -> (b2 h) (b1 w)', reduction='mean', b1=2)对什么求均值,什么就会消失,同时还带了一个rearrange的操作,先分为(b1,b2),再拼接,再求均值

repeat

repeat(image, 'b h w -> (b h) w c', c=3)在channel上重复3次,再吧b和h合并,也就是b张图片垂直拼接

torch.einsum

一般用于矩阵乘法
在这里插入图片描述
np.einsum('ij,jk->ik', A, B)
(1)先看是否有维度重复
若有重复,则按此维度相乘
如上图所示,j和j重复,即矩阵1的行与矩阵2的列相乘
(2)再看是否有维度消失
j这个维度消失了,说明相乘后还得相加,即矩阵乘法
若没消失,即np.einsum('ij,jk->ijk', A, B)
相乘后不相加,第一行与第一列相乘,得到新的一列——>第一行与矩阵2相乘,得到新的矩阵1,第二行与矩阵2相乘,得到新的矩阵2,第3行与矩阵2相乘,得到新的矩阵3,3个新的矩阵堆叠起来,维度即为(3,3,3):3个3x3的矩阵
在这里插入图片描述

这是一个crossattention模块:class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) self.scale = dim_head ** -0.5 self.heads = heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) self.to_out = nn.Sequential( nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) ) def forward(self, x, context=None, mask=None): h = self.heads q = self.to_q(x) context = default(context, x) k = self.to_k(context) v = self.to_v(context) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) # force cast to fp32 to avoid overflowing if _ATTN_PRECISION =="fp32": with torch.autocast(enabled=False, device_type = 'cuda'): q, k = q.float(), k.float() sim = einsum('b i d, b j d -> b i j', q, k) * self.scale else: sim = einsum('b i d, b j d -> b i j', q, k) * self.scale del q, k if exists(mask): mask = rearrange(mask, 'b ... -> b (...)') max_neg_value = -torch.finfo(sim.dtype).max mask = repeat(mask, 'b j -> (b h) () j', h=h) sim.masked_fill_(~mask, max_neg_value) # attention, what we cannot get enough of sim = sim.softmax(dim=-1) out = einsum('b i j, b j d -> b i d', sim, v) out = rearrange(out, '(b h) n d -> b n (h d)', h=h) return self.to_out(out) 我如何从中提取各个提示词的注意力热力图并用Gradio可视化?
07-15
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值