代码中相应注释去掉,即可执行。einops需python3.7及以上环境。
rearrange, repeat, reduce等方法,网上已有相应使用记录。
import torch
from einops import rearrange, repeat, reduce, pack, unpack
import numpy as np
prompt = torch.rand((1, 10))
# # 对两个prompt, 在第0维度进行拼接
# x, leading_dims = pack([prompt, prompt], '* n')
# print(x.size()) # (2, 10)
# print(leading_dims) # [torch.Size([1]), torch.Size([1])]
packed_shapes = [(5,), (5,)] # 需是可迭代的,将最后一维分解为2部分,5+5=10
# packed_shapes也可以写作 packed_shapes = [torch.Size([5]), torch.Size([5])]
pattern = 'x *' # 星号*表示分解的维度
y= unpack(torch.rand((1, 10)), packed_shapes, pattern)# pattern
print(len(y)) # 输出为List类型,长度为2
print(y[0].size()) # torch.Size([1, 5])
# packed_shapes = [torch.Size([5]), torch.Size([5])] # 需是可迭代的,将最后一维分解为2部分,5+5=10
# pattern = 'x *' # 星号*表示分解的维度
# y= unpack(torch.rand((1, 10)), packed_shapes, pattern)# pattern
# print(len(y)) # 输出为List类型,长度为2
# print(y[0].size()) # torch.Size([1, 5])
############ array类型数据
h, w = 1, 2
image_rgb = np.random.random([h, w, 3])
image_depth = np.random.random([h, w])
# # pack在最后一维进行拼接, 先扩展为(h, w, 1),再在最后一维进行拼接,输出的维度为(h, w, 3+1)
# image_rgbd, ps = pack([image_rgb, image_depth], 'x y *')
# print(image_rgbd.shape) # (1, 2, 4)
# print(ps) # [(3,), ()]
# # 对两个(h,w)的张量,先扩展为(h, w, 1),再在最后一维进行拼接,输出的维度为(h, w, 2)
# image_rgbd, ps = pack([image_depth, image_rgb[:,:,0]], 'x y *')
# print(image_rgbd.shape) # (1, 2, 2)
# print(ps) # [(), ()]
# # 对两个(h,w)的张量,在最后一维(星号*所在位置)进行拼接,输出的维度为(h, 2w)
# image_rgbd, ps = pack([image_depth, image_rgb[:,:,0]], 'x *')
# print(image_depth.shape)
# print(image_rgbd.shape) # (1, 4)
# print(ps) # [(2,), (2,)]
# ###### unpack使用
# packed_shapes = [(4,), (1,), (1,)] # 需是可迭代的,将最后一维分解为3部分,4+1+1=6
# pattern = 'x *' # 星号*表示分解的维度
# y= unpack(np.random.random([h, 6]), packed_shapes, pattern)# pattern
# print(len(y)) # 输出为List类型,长度为3
# print(y[0].shape) # (h, 4)