torch.chunk(tensor, chunks, dim=0)
将tensor 拆分成相应的组块, 最后一块会小一些如果不能整除的话。
a的size是(2,4),那么torch.chunk(a,2,0), 0表示以第一个维度,分成两块,第一个维度就是指(2,4)里面的2
torch.split(tensor, split_size_or_sections, dim=0)
将tensor 拆分成相应的组块,split_size_or_sections指定分裂后该维度的维度值
a的size是(2,4),那么torch.split(a,1,0), 0表示以第一个维度,分块后每块第一个维度是1,第一个维度就是指(2,4)里面的2, 分开后的纬度值是(1,4)和(1,4)
详见代码块的注释
import torch
a = torch.arange(8).reshape(2,4)
b = torch.arange(9).reshape(3,3)
c = torch.arange(18).reshape(2,3,3)
print('a:',a, '\nb:', b,'\nc:', c)
a: tensor([[0, 1, 2, 3],
[4, 5, 6, 7]])
b: tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
c: tensor([[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9,