einops基础用法

基本概念

  • 自由索引:出现在箭头右边的索引, 可以遍历的索引。
  • 求和索引:只出现在箭头左边的索引,表示中间计算结果需要这个维度上求和之后才能得到输出。

基础规则

  • 规则一,equation 箭头左边,在不同输入之间重复出现的索引表示,把输入张量沿着该维度做乘法操作
  • 规则二,只出现在 equation 箭头左边的索引,表示中间计算结果需要在这个维度上求和,也就是上面提到的求和索引
  • 规则三,equation 箭头右边的索引顺序可以是任意的,比如上面的 “ik,kj->ij” 如果写成 “ik,kj->ji”,那么就是返回输出结果的转置

特殊规则

  • equation 可以不写包括箭头在内的右边部分,那么在这种情况下,输出张量的维度会根据默认规则推导。就是把输入中只出现一次的索引取出来,然后按字母表顺序排列,比如上面的矩阵乘法 “ik,kj->ij” 也可以简化为 “ik,kj”,根据默认规则,输出就是 “ij” 与原来一样。
  • equation 中支持 “…” 省略号,用于表示用户并不关心的索引,比如只对一个高维张量的最后两维做转置可以这么写
a = torch.zeros([2, 3, 4, 5, 6])
a = torch.einsum('...ij -> ...ji', a)
print(f"the dealed a shape is {a.shape}")
# the dealed a shape is torch.Size([2, 3, 4, 6, 5])

实例

# 取出对角线元素
b = torch.arange(16).reshape(4, 4)
print(b)
b = torch.einsum('ii->i', b)
print(b)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11],
#         [12, 13, 14, 15]])
# tensor([ 0,  5, 10, 15])

# get sum
a = torch.arange(6).reshape(2, 3)
print(a)
a = torch.einsum('ij ->', a)
print(a)
# tensor([[0, 1, 2],
#         [3, 4, 5]])
# tensor(15)

# get sum by row、clo
a = torch.arange(6).reshape(2, 3)
print(f"sum example : \na = {a}")
print(torch.einsum('ij -> i', a))
print(torch.einsum('ij -> j', a))
# sum example : 
# a = tensor([[0, 1, 2],
#         [3, 4, 5]])
# tensor([ 3, 12])
# tensor([3, 5, 7])

用法表格请添加图片描述

在这里插入图片描述

rearange

import torch
from einops import rearrange
 
images = torch.randn((32,30,40,3))
# (32, 30, 40, 3)
print(rearrange(images, 'b h w c -> b h w c').shape)
 
# (960, 40, 3)
print(rearrange(images, 'b h w c -> (b h) w c').shape)
 
# (30, 1280, 3)
print(rearrange(images, 'b h w c -> h (b w) c').shape)
 
# (32, 3, 30, 40)
print(rearrange(images, 'b h w c -> b c h w').shape)
 
# (32, 3600)
print(rearrange(images, 'b h w c -> b (c h w)').shape)
 
# ---------------------------------------------
# 这里(h h1) (w w1)就相当于h与w变为原来的1/h1,1/w1倍
 
# (128, 15, 20, 3)
print(rearrange(images, 'b (h h1) (w w1) c -> (b h1 w1) h w c', h1=2, w1=2).shape)
 
# (32, 15, 20, 12)
print(rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape)

repeat

import torch
from einops import repeat
 
image = torch.randn((30,40))
 
# 整体复制 (30, 40, 3)
print(repeat(image, 'h w -> h w c', c=3).shape)
 
# 按行复制 (60, 40)
print(repeat(image, 'h w -> (repeat h) w', repeat=2).shape)
 
# 按列复制 (30, 120) 注意:(repeat w)与(w repeat)结果是不同的
print(repeat(image, 'h w -> h (repeat w)', repeat=3).shape)
 
# (60, 80)
print(repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape)

reduce

import torch
from einops import reduce
 
x = torch.randn(3, 5, 5)
# (5, 5)
print(reduce(x, 'c h w -> h w', 'max').shape)
 
x = torch.randn(1, 3, 6, 6)
# (1, 3, 3, 3) 注意:如果不是整除会报错
y1 = reduce(x, 'b c (h h1) (w w1) -> b c h w', 'max', h1=2, w1=2)
print(y1.shape)
 
# Adaptive max-pooling:(1, 3, 3, 2)
print(reduce(x, 'b c (h h1) (w w1) -> b c h1 w1', 'max', h1=3, w1=2).shape)
 
# Global average pooling:(1, 3)
print(reduce(x, 'b c h w -> b c', 'mean').shape)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值