Python常用库:rearrange函数——转换数组维度

Python常用库:rearrange函数——转换数组维度

einops.rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths)

功能:重新划分张量维度,可以实现数组的转置、拆分、合并等操作。

输入:

  • tensor:需要调整维度的张量数据;
  • pattern:调整规则;
  • axes_lengths:附加的尺寸规格;

注意:

  • 如果需要将某一维度拆分成多个维度,需要额外指定一些附加的尺寸规格变量,同时拆分或者合并维度时,注意变量顺序;

代码案例

拆分

import torch
from einops import rearrange

data = torch.range(1, 10)
data1 = rearrange(data, '(a b) -> a b', a=2, b=5)
data2 = rearrange(data, '(b a) -> a b', a=2, b=5)
print(data1)
print(data2)

输出

# (a b) -> a b
tensor([[ 1.,  2.,  3.,  4.,  5.],
        [ 6.,  7.,  8.,  9., 10.]])
# (b a) -> a b
tensor([[ 1.,  3.,  5.,  7.,  9.],
        [ 2.,  4.,  6.,  8., 10.]])

注意:(a b) -> a b时,相当于直接按顺序拆分,每b个为1组,一共分出a组来,b看成每组的特征长度,a看成组数;(b a) -> a b时,相当于先把数据划分成(b, a)的,之后再做一次转置,即:

print(data.reshape(5, 2).transpose(-1, -2))

# 输出
tensor([[ 1.,  3.,  5.,  7.,  9.],
        [ 2.,  4.,  6.,  8., 10.]])

合并

import torch
from einops import rearrange

data = torch.range(1, 10).reshape(2, 5)
data1 = rearrange(data, 'a b -> (a b)')
data2 = rearrange(data, 'a b -> (b a)')
print(data)
print(data1)
print(data2)

输出

tensor([[ 1.,  2.,  3.,  4.,  5.],
        [ 6.,  7.,  8.,  9., 10.]])
# a b -> (a b)
tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])
# a b -> (b a)
tensor([ 1.,  6.,  2.,  7.,  3.,  8.,  4.,  9.,  5., 10.])

注意:(a b) -> a b时,相当于直接按顺序合并,a个组的特征,按顺序串联合并;a b -> (b a)时,相当于先把数组做转置,之后再合并,即:

print(data.transpose(-1, -2).reshape(-1))

tensor([ 1.,  6.,  2.,  7.,  3.,  8.,  4.,  9.,  5., 10.])

官网文档:https://einops.rocks/api/rearrange/

  • 14
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

视觉萌新、

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值