使用 einops 库中的 rearrange、reduce 和 repeat 这三个函数，可以简洁地替代 PyTorch 中的许多形状变换操作（如 transpose、reshape、unsqueeze、求和/求均值、tile 等）。
import torch
from einops import rearrange,reduce,repeat
x = torch.randn(2, 3, 4, 5)

# 1. transpose: swap axes 1 and 2.
out1 = x.transpose(1, 2)
out2 = rearrange(x, 'b i h w -> b h i w')
assert torch.equal(out1, out2)

# 2. reshape: merge the first two axes into one.
out1 = x.reshape(-1, *x.shape[2:])
out2 = rearrange(x, 'b i h w -> (b i) h w')
assert torch.equal(out1, out2)
# Splitting a merged axis needs the sizes of all but one component; read
# the batch size from the tensor instead of hard-coding it.
out3 = rearrange(out2, '(b i) h w -> b i h w', b=x.shape[0])
# A pure rearrangement is lossless, so the round-trip must match exactly.
flag = torch.equal(out3, x)
print(flag)
einops 还可以用于 Vision Transformer（ViT）模型中把图像切分为 patch 的操作（img2patch）：
# img2patch: cut each image into non-overlapping patch_size x patch_size
# patches, the input layout used by a ViT patch embedding.
x = torch.randn(2, 3, 224, 224)
patch_size = 16  # side length of one square patch

# [b, c, H, W] -> [b, c, num_patches, patch_size * patch_size]
out1 = rearrange(
    x,
    'b i (h1 p1) (w1 p2) -> b i (h1 w1) (p1 p2)',
    p1=patch_size,
    p2=patch_size,
)
print(out1.shape)  # torch.Size([2, 3, 196, 256])

# Fold channels into the feature axis: [b, num_patches, patch_depth],
# where patch_depth = patch_size * patch_size * c.
out2 = rearrange(out1, 'b i n a -> b n (a i)')
print(out2.shape)  # torch.Size([2, 196, 768])

# out1 and out2 have different shapes, so they cannot be compared directly.
# Instead, check the two-step result against the usual single-step pattern.
patches = rearrange(
    x,
    'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
    p1=patch_size,
    p2=patch_size,
)
assert torch.equal(out2, patches)
池化（归约）操作，使用 reduce 实现：
# reduce: every axis that is missing from the output pattern is reduced.
# Mean over the width axis only (NOT a 2-D average pool): -> [b, c, h]
out1 = reduce(x, 'b i h w -> b i h', 'mean')
assert torch.allclose(out1, x.mean(dim=-1))

# Sum over width but keep it as a size-1 axis (like keepdim=True): -> [b, c, h, 1]
out2 = reduce(x, 'b i h w -> b i h 1', 'sum')
assert torch.allclose(out2, x.sum(dim=-1, keepdim=True))

# Global max pool over both spatial axes: -> [b, c]
out3 = reduce(x, 'b i h w -> b i', 'max')
assert torch.equal(out3, x.amax(dim=(2, 3)))
print(out3.shape)

# A real 2x2 average pool: split each spatial axis into (blocks, 2), then
# average over the two size-2 sub-axes. -> [b, c, h / 2, w / 2]
out4 = reduce(x, 'b i (h h2) (w w2) -> b i h w', 'mean', h2=2, w2=2)
assert torch.allclose(out4, torch.nn.functional.avg_pool2d(x, 2), atol=1e-6)
print(out4.shape)
最后是 repeat 操作：
# repeat
# Append a trailing unit axis (same as x.unsqueeze(-1)): -> [b, c, h, w, 1]
out1 = rearrange(x, 'b i h w -> b i h w 1')
assert torch.equal(out1, x.unsqueeze(-1))

# Repeat along that unit axis twice (same as torch.tile): -> [b, c, h, w, 2]
out2 = repeat(out1, 'b i h w 1 -> b i h w 2')
assert torch.equal(out2, torch.tile(out1, (1, 1, 1, 1, 2)))

# '(2 h)' puts the repeat count OUTSIDE h, so the whole image is tiled 2x2
# (same as x.repeat(1, 1, 2, 2)). Writing '(h 2)' would instead duplicate
# each pixel, i.e. nearest-neighbour 2x upsampling.
out3 = repeat(x, 'b i h w -> b i (2 h) (2 w)')
assert torch.equal(out3, x.repeat(1, 1, 2, 2))
print(out3.shape)  # torch.Size([2, 3, 448, 448])