torch.split(tensor, int/list, dim)
将tensor切分为块,每个块都是原始tensor的一个视图
- 如果传入int值,tensor将被切分成大小相同的块,如果切分维度大小与传入的值,不能整除,那么最后一块要小一些
- 如果传入一个list,按照list给出的大小序列进行切分,list的大小总和要与切分维度的大小相同
- 返回值回为一个张量元组
tensor.split(int/list, dim)
等价于torch.split(tensor, int/list, dim)
示例
import torch
t1 = torch.rand(4, 5, 8)
t2 = torch.split(t1, [5, 3], dim=2) # 等价:t2 = t1.split([5, 3], dim=2)
print(type(t2)) # <class 'tuple'>
print(t2[0].shape) # torch.Size([4, 5, 5])
print(t2[1].shape) # torch.Size([4, 5, 3])