# torch.split(tensor, split_size_or_sections, dim) splits `tensor` along dimension `dim`
# according to split_size_or_sections and returns a tuple of the resulting tensors.
# If split_size_or_sections is an integer in the range [1, tensor.shape[dim]], it is the size
# of each chunk along `dim` (rows/columns per split); the number of chunks is
# tensor.shape[dim] / split_size when the division is exact, and quotient + 1 otherwise
# (the last chunk holds the remainder).
# split_size_or_sections can also be a list or tuple with at most tensor.shape[dim] elements
# whose elements sum to tensor.shape[dim]; then the i-th chunk has size sections[i] along
# `dim`, so the chunk sizes can all differ, and the number of chunks is len(sections).
# In short, torch.split handles the case where tensor.shape[dim] is not evenly divisible,
# making it a more flexible version of torch.chunk(tensor, chunks, dim).
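To make the difference concrete: torch.chunk takes the *number* of chunks, while torch.split takes the *size* of each chunk. A minimal side-by-side sketch:

```python
import torch

x = torch.arange(10)

# torch.chunk asks for the NUMBER of chunks: each chunk gets ceil(10 / 3) == 4
# elements except the last, giving sizes 4, 4, 2.
print([c.numel() for c in torch.chunk(x, 3)])

# torch.split asks for the SIZE of each chunk: chunks of 3 until the tensor is
# exhausted, giving sizes 3, 3, 3, 1.
print([c.numel() for c in torch.split(x, 3)])
```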
# Example 1: split_size is an integer
import torch

# Create a random tensor of shape (2, 4, 6)
split_tensor = torch.randn(size=(2, 4, 6))  # entries drawn from the standard normal distribution N(0, 1)
print(split_tensor)
"""
tensor([[[-0.0252, 0.8703, 0.1552, -0.0961, 0.4302, 0.0788],
[-0.4389, -0.0292, 0.1837, 0.9859, -0.3977, 0.5684],
[ 0.8940, -1.4846, 0.7611, 0.0483, -2.0573, 2.1025],
[-0.5657, -0.3805, 1.2321, 1.5162, 0.6435, -0.1696]],
[[ 0.4291, -0.2098, 0.6542, 1.1694, 0.2017, -0.0526],
[-1.2315, 0.0151, 0.6965, 0.3926, 0.3974, 0.8113],
[-1.8101, -0.0031, 0.8198, -1.3040, 1.0232, 1.2221],
[-0.6071, 1.5682, -0.2740, -0.2582, 0.4433, 0.5099]]])
"""
# Split along dimension dim=1
# split_size=2, dim=1
split_tensor1_1 = split_tensor.split(split_size=2, dim=1)
print(split_tensor1_1)
"""
(tensor([[[-0.0252, 0.8703, 0.1552, -0.0961, 0.4302, 0.0788],
[-0.4389, -0.0292, 0.1837, 0.9859, -0.3977, 0.5684]],
[[ 0.4291, -0.2098, 0.6542, 1.1694, 0.2017, -0.0526],
[-1.2315, 0.0151, 0.6965, 0.3926, 0.3974, 0.8113]]]),
tensor([[[ 0.8940, -1.4846, 0.7611, 0.0483, -2.0573, 2.1025],
[-0.5657, -0.3805, 1.2321, 1.5162, 0.6435, -0.1696]],
[[-1.8101, -0.0031, 0.8198, -1.3040, 1.0232, 1.2221],
[-0.6071, 1.5682, -0.2740, -0.2582, 0.4433, 0.5099]]]))
We split along dim=1. Since split_tensor.shape == (2, 4, 6), split_tensor.shape[1] == 4, and with split_size=2 each chunk has size 2 along dim=1 (other dimensions unchanged).
The number of chunks == split_tensor.shape[1] / split_size == 2, so a tuple of two tensors is returned.
"""
print(split_tensor1_1[0].shape) # torch.Size([2, 2, 6]) == split_tensor1_1[1].shape
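Splitting discards nothing, so concatenating the chunks back along the same dim restores the original tensor; a quick sketch of that round trip:

```python
import torch

t = torch.randn(2, 4, 6)
parts = t.split(2, dim=1)          # two chunks of shape (2, 2, 6)

assert len(parts) == 2
assert all(p.shape == (2, 2, 6) for p in parts)
# torch.cat along the same dim is the inverse of an even split
assert torch.equal(torch.cat(parts, dim=1), t)
```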
# split_size=3, dim=1
split_tensor1_2 = split_tensor.split(split_size=3, dim=1)
print(split_tensor1_2)
print(split_tensor1_2[0].shape) # torch.Size([2, 3, 6])
print(split_tensor1_2[1].shape) # torch.Size([2, 1, 6])
"""
(tensor([[[-0.0252, 0.8703, 0.1552, -0.0961, 0.4302, 0.0788],
[-0.4389, -0.0292, 0.1837, 0.9859, -0.3977, 0.5684],
[ 0.8940, -1.4846, 0.7611, 0.0483, -2.0573, 2.1025]],
[[ 0.4291, -0.2098, 0.6542, 1.1694, 0.2017, -0.0526],
[-1.2315, 0.0151, 0.6965, 0.3926, 0.3974, 0.8113],
[-1.8101, -0.0031, 0.8198, -1.3040, 1.0232, 1.2221]]]),
tensor([[[-0.5657, -0.3805, 1.2321, 1.5162, 0.6435, -0.1696]],
[[-0.6071, 1.5682, -0.2740, -0.2582, 0.4433, 0.5099]]]))
Here split_tensor.shape[1] / split_size == 4 / 3 leaves a remainder, so the number of
chunks is 1 + 1 == 2. Ideally every chunk along dim=1 would have size 3, but after the
first chunk only one row remains, so the last chunk just takes whatever is left (size 1
here). In general every chunk has size split_size along `dim`, except that the last
chunk's size is the remainder when the division is not exact:

remainder = split_tensor.shape[1] % split_size
if remainder == 0:
    num_chunks = split_tensor.shape[1] // split_size
else:
    num_chunks = split_tensor.shape[1] // split_size + 1
assert num_chunks >= 1
"""
# When split_size == tensor.shape[dim], a single split covers the whole dimension,
# so a one-element tuple containing the original tensor is returned
split_tensor1_3 = split_tensor.split(split_size=4, dim=1)
print(split_tensor1_3)
"""
(tensor([[[-0.0252, 0.8703, 0.1552, -0.0961, 0.4302, 0.0788],
[-0.4389, -0.0292, 0.1837, 0.9859, -0.3977, 0.5684],
[ 0.8940, -1.4846, 0.7611, 0.0483, -2.0573, 2.1025],
[-0.5657, -0.3805, 1.2321, 1.5162, 0.6435, -0.1696]],
[[ 0.4291, -0.2098, 0.6542, 1.1694, 0.2017, -0.0526],
[-1.2315, 0.0151, 0.6965, 0.3926, 0.3974, 0.8113],
[-1.8101, -0.0031, 0.8198, -1.3040, 1.0232, 1.2221],
[-0.6071, 1.5682, -0.2740, -0.2582, 0.4433, 0.5099]]]),)
"""
# Example 2: split_size is a list
# Create a random tensor of shape (1, 2, 3)
split_tensor2 = torch.randn((1, 2, 3))
print(split_tensor2)
"""
tensor([[[-0.7475, 0.5178, -0.0279],
[-0.2505, 1.4757, 0.9539]]])
"""
# split_size = [2, 1], dim=2
split_tensor2_1 = split_tensor2.split(split_size=[2, 1], dim=2)
print(split_tensor2_1)
"""
(tensor([[[-0.7475, 0.5178],
[-0.2505, 1.4757]]]),
 tensor([[[-0.0279],
          [ 0.9539]]]))
"""
print(split_tensor2_1[0].shape) # torch.Size([1, 2, 2])
print(split_tensor2_1[1].shape) # torch.Size([1, 2, 1])
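Two properties of the list form worth knowing: the sections may be unequal, and (as I understand the API) a list whose elements do not sum to t.shape[dim] is rejected with a RuntimeError:

```python
import torch

t = torch.randn(1, 2, 3)

# Sections can have any sizes, as long as they sum to t.shape[2] == 3
a, b = t.split([2, 1], dim=2)
assert a.shape == (1, 2, 2) and b.shape == (1, 2, 1)

# Sections that do not sum to t.shape[dim] raise an error
try:
    t.split([2, 2], dim=2)
except RuntimeError:
    print("sections must sum to t.shape[dim]")
```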