# 插一个 map —— aside: Python's built-in map()
def square(x):
    """Return ``x`` squared; used as the callback passed to ``map`` below.

    Named ``square`` rather than ``sum`` so the builtin ``sum`` is not
    shadowed for the rest of the module.
    """
    return x ** 2


m = (1, 2, 3)
# map() lazily applies square() to every element; list() materializes
# the resulting iterator so it can be printed.
print(list(map(square, m)))
# Output: [1, 4, 9]
# einops.rearrange — reorder, merge and split tensor axes
import torch
from einops import rearrange
# Random batch with layout (b, h, w, c) = (32, 30, 40, 3).
images = torch.randn(32, 30, 40, 3)

# Identity pattern -- nothing moves: (32, 30, 40, 3)
print(rearrange(images, 'b h w c -> b h w c').shape)

# Fold batch into height (images stacked vertically): (32*30, 40, 3) = (960, 40, 3)
print(rearrange(images, 'b h w c -> (b h) w c').shape)

# Fold batch into width (images side by side): (30, 32*40, 3) = (30, 1280, 3)
print(rearrange(images, 'b h w c -> h (b w) c').shape)

# Channels-last -> channels-first: (32, 3, 30, 40)
print(rearrange(images, 'b h w c -> b c h w').shape)

# Flatten each image into one vector: (32, 3*30*40) = (32, 3600)
print(rearrange(images, 'b h w c -> b (c h w)').shape)

# ---------------------------------------------
# Splitting an axis: '(h h1)' reads height as h blocks of h1 rows each
# (h is the outer factor, h1 the inner one), so once h1 / w1 are moved
# elsewhere the spatial size shrinks to 1/h1 and 1/w1 of the original.

# Move the 2x2 sub-pixel offsets into the batch: (32*2*2, 15, 20, 3) = (128, 15, 20, 3)
print(rearrange(images, 'b (h h1) (w w1) c -> (b h1 w1) h w c', h1=2, w1=2).shape)

# Move them into channels instead (space-to-depth): (32, 15, 20, 3*2*2) = (32, 15, 20, 12)
print(rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape)
# einops.repeat — copy / tile a tensor along new or existing axes
import torch
from einops import repeat
# A single 2-D tensor of shape (h, w) = (30, 40).
image = torch.randn(30, 40)

# New axis c: the whole image is copied 3 times along it -> (30, 40, 3)
print(repeat(image, 'h w -> h w c', c=3).shape)

# 'repeat' is just an axis name whose size comes from the keyword repeat=2.
# Placing it OUTSIDE h tiles the whole image vertically -> (60, 40)
print(repeat(image, 'h w -> (repeat h) w', repeat=2).shape)

# Tile horizontally 3 times -> (30, 120).
# Order inside the parentheses matters: '(repeat w)' tiles the whole row
# (abc -> abcabcabc), whereas '(w repeat)' repeats each column in place
# (abc -> aaabbbccc); the shapes match but the contents differ.
print(repeat(image, 'h w -> h (repeat w)', repeat=3).shape)

# Each pixel becomes a 2x2 block (nearest-neighbour upsampling) -> (60, 80)
print(repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape)
# einops.reduce — rearrange + reduction (max / mean / ...) over dropped axes
import torch
from einops import reduce
x = torch.randn(3, 5, 5)
# Max over the channel axis c (c is dropped from the output) -> (5, 5)
print(reduce(x, 'c h w -> h w', 'max').shape)

x = torch.randn(1, 3, 6, 6)
# 2x2 max-pooling -> (1, 3, 3, 3). h1/w1 are the INNER factors, so each
# reduced group is a contiguous 2x2 window. Note: the height/width must be
# divisible by h1/w1, otherwise einops raises an error.
y1 = reduce(x, 'b c (h h1) (w w1) -> b c h w', 'max', h1=2, w1=2)
print(y1.shape)

# Adaptive max-pooling to a fixed 3x2 output grid -> (1, 3, 3, 2).
# The KEPT axes (h1, w1) must be the OUTER factors of each composite axis,
# so every output cell pools a contiguous window (rows {0,1},{2,3},{4,5};
# cols {0,1,2},{3,4,5}). Keeping the inner factor instead, as in
# 'b c (h h1) (w w1) -> b c h1 w1', gives the same shape but pools
# strided, non-adjacent pixels (rows {0,3},{1,4},{2,5}).
print(reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=2).shape)

# Global average pooling over all spatial positions -> (1, 3)
print(reduce(x, 'b c h w -> b c', 'mean').shape)
# einsum
import torch
# Operand shapes: a is (b, h, i, d) = (1, 1, 3, 2), b is (b, h, j, d) = (1, 1, 1, 2).
a = torch.randn(1, 1, 3, 2)
b = torch.randn(1, 1, 1, 2)

# Contract over the shared last axis d -> (b, h, i, j) = (1, 1, 3, 1).
# Same result as torch.matmul(a, b.transpose(2, 3)). torch.einsum also
# accepts the operands bundled in one list, which is the form used here;
# the variadic form torch.einsum(eq, a, b) is equivalent.
c = torch.einsum('b h i d, b h j d -> b h i j', [a, b])
print(c.shape)
# torch.matmul