1. torch.chunk(tensor, chunk_num, dim)与torch.cat()结果相反,
它是将tensor按dim(行或列)平均分割成chunk_num个tensor块,返回的是一个元组。
import torch
a = torch.Tensor([[1,2,4]])
b = torch.Tensor([[4,5,7], [3,9,8], [9,6,7]])
c = torch.cat((a,b), dim=0)
print(c)
print(c.size())
print('********************')
d = torch.chunk(c,3,dim=0) # 等价于torch.chunk(c,3,dim=0),平均分成2份
print(d)
print(len(d))
返回结果
[1]https://zhuanlan.zhihu.com/p/59141209
2. Pytorch数组反转(数组倒序)函数flip的使用
对n维张量的指定维度进行反转(倒序)