一句话总结:einops负责变形操作,einsum负责乘法与加法操作
参考https://zhuanlan.zhihu.com/p/101157166
https://zhuanlan.zhihu.com/p/372692913
einops
from einops import rearrange,repeat,reduce
import torch
rearrange
做维度操作,比如拉平,拼接,调换维度顺序,分patch等
eg.
output = rearrange(a, 'c (r p) w -> c r p w', p=3)
把本来的(c,h,w)的h拆成(r,p),指定p的值,自动计算r的值
rearrange(images, 'b h w c -> (b h) w c')
把b和h合并在一起,这是什么意思?
h是图像的高,b是图像的数量,合并之后,相当于图像的高度扩大了b倍,实际上就是把b张图像垂直拼接在一起
更复杂的:
image = rearrange(images, '(b1 b2) h w c -> (b1 h) (b2 w) c', b1=2)
(1)(b,h,w,c)先被拆成了(b1,b2,h,w,c),也就是说b张图片变成了b1组,每组b2张图片
(2)b1和h合并,b2和w合并,根据上一个例子,我们可以知道,就是把b1张图片垂直拼接,b2张图片水平拼接,最后的形状是(b1 h) (b2 w) c
,从后往前看,可以理解为每组b2张图片先水平拼接,再把所有组拼接好的图片垂直拼接
分patch操作:
image = rearrange(images, 'b (h p_h) (w p_w) c -> b (h w) p_h p_w c', p_h = 150, p_w = 200) # 一张图划分为4个patch
reduce
一般用来做pooling操作
reduce(images, 'b h w c -> b h w', reduction='mean')
AVG pooling,最后c没有了,可以知道是对通道求均值,每个位置的所有通道求均值
reduce(images, '(b1 b2) h w c -> (b2 h) (b1 w)', reduction='mean', b1=2)
对什么求均值,什么就会消失,同时还带了一个rearrange的操作,先分为(b1,b2),再拼接,再求均值
repeat
repeat(image, 'b h w -> (b h) w c', c=3)
在channel上重复3次,再吧b和h合并,也就是b张图片垂直拼接
torch.einsum
一般用于矩阵乘法
np.einsum('ij,jk->ik', A, B)
(1)先看是否有维度重复
若有重复,则按此维度相乘
如上图所示,j和j重复,即矩阵1的行与矩阵2的列相乘
(2)再看是否有维度消失
j这个维度消失了,说明相乘后还得相加,即矩阵乘法
若没消失,即np.einsum('ij,jk->ijk', A, B)
即相乘后不相加,第一行与第一列相乘,得到新的一列——>第一行与矩阵2相乘,得到新的矩阵1,第二行与矩阵2相乘,得到新的矩阵2,第3行与矩阵2相乘,得到新的矩阵3,3个新的矩阵堆叠起来,维度即为(3,3,3):3个3x3的矩阵