目录
一、参数详解
torch.split(tensor, split_size_or_sections, dim=0)
含义:将一个张量分为几个chunks
torch.split(tensor, split_size_or_sections, dim=0)
我的理解:
split_size_or_sections参数,两两种类型,int值或者list数组,
int值,把tensor在维度dim上,按照容量split_size_or_sections进行等量切分。
如果不能整除split_size,那么最后一个chunk相较于其它chunk小;
如果是一个list列表,该方法会将tensor划分为len(split_size_or_sections)的张量。
dim:划分张量所依据的维度
return:返回的是一个tuple
int等量切分案例1
import torch
x = torch.arange(20).reshape(5, 4)
print(x)
y = torch.split(x, 3,0)
print(y)
y1 = torch.split(x, 3,1)
print(y1)
y是0维