torch.split()函数用法讲解[含代码实例]

torch.split(tensor, split_size, dim)函数对张量tensor在指定维度(dim)按参数split_size_or_sections进行分割,最终返回分割后的张量组成的元组(turple);


 

# 其中split_size可以是区间[ 1, tensor.shape[dim] ]的整数,表示每一次在指定维度(dim)分割的值(行数/列数),最终分块张量的数量==tensor.shape[dim] / split_size的值
# 若除不尽,则分块张量的数量为最大商+1(有余数>=1);若可整除,即为商;
# split_size也可是元素个数不超过tensor.shape[dim]且元素和=tensor.shape[dim]的列表或元组,列表元素个数即分块张量数;
# 列表元素指定分块方式,即每一次在指定维度(dim)分割的值(行数/列数),这里每次分割值都可能不同!
# 总体来说,torch.split(tensor, split_size, dim)函数适用于分块数不能被tensor.shape[dim]整除的情形,像是高配版的torch.chunk(tensor, sections, dim)函数。

# 例1:split_size是一个整数
# 创建一个shape=(2, 4, 6)的随机张量
split_tensor = torch.randn(size=(2, 4, 6))  # 张量元素服从标准正态分布N(0, I)
print(split_tensor)
"""
tensor([[[-0.0252,  0.8703,  0.1552, -0.0961,  0.4302,  0.0788],
         [-0.4389, -0.0292,  0.1837,  0.9859, -0.3977,  0.5684],
         [ 0.8940, -1.4846,  0.7611,  0.0483, -2.0573,  2.1025],
         [-0.5657, -0.3805,  1.2321,  1.5162,  0.6435, -0.1696]],
        [[ 0.4291, -0.2098,  0.6542,  1.1694,  0.2017, -0.0526],
         [-1.2315,  0.0151,  0.6965,  0.3926,  0.3974,  0.8113],
         [-1.8101, -0.0031,  0.8198, -1.3040,  1.0232,  1.2221],
         [-0.6071,  1.5682, -0.2740, -0.2582,  0.4433,  0.5099]]])
"""

# 对维度(dim=1)进行分割

# split_size=2, dim=1
split_tensor1_1 = split_tensor.split(split_size=2, dim=1)
print(split_tensor1_1)
"""
(tensor([[[-0.0252,  0.8703,  0.1552, -0.0961,  0.4302,  0.0788],
         [-0.4389, -0.0292,  0.1837,  0.9859, -0.3977,  0.5684]],
        [[ 0.4291, -0.2098,  0.6542,  1.1694,  0.2017, -0.0526],
         [-1.2315,  0.0151,  0.6965,  0.3926,  0.3974,  0.8113]]]), 
 tensor([[[ 0.8940, -1.4846,  0.7611,  0.0483, -2.0573,  2.1025],
         [-0.5657, -0.3805,  1.2321,  1.5162,  0.6435, -0.1696]],
        [[-1.8101, -0.0031,  0.8198, -1.3040,  1.0232,  1.2221],
         [-0.6071,  1.5682, -0.2740, -0.2582,  0.4433,  0.5099]]]))
我们指定处理维度dim=1,而split_tensor.shape == (2, 4, 6),故split_tensor.shape[dim=1]==4,而split_size=2,因此在维度dim=1上(其他维度不变),每次分割值为2,
可分割次数 == split_tensor.shape[dim=1] / split_size == 2,故分块张量个数为2,最终返回两个分割张量组成的元组。
"""
print(split_tensor1_1[0].shape)  # torch.Size([2, 2, 6]) == split_tensor1_1[1].shape

# split_size=3, dim=1
split_tensor1_2 = split_tensor.split(split_size=3, dim=1)
print(split_tensor1_2)
print(split_tensor1_2[0].shape)  # torch.Size([2, 3, 6])
print(split_tensor1_2[1].shape)  # torch.Size([2, 1, 6])
"""
(tensor([[[-0.0252,  0.8703,  0.1552, -0.0961,  0.4302,  0.0788],
         [-0.4389, -0.0292,  0.1837,  0.9859, -0.3977,  0.5684],
         [ 0.8940, -1.4846,  0.7611,  0.0483, -2.0573,  2.1025]],
        [[ 0.4291, -0.2098,  0.6542,  1.1694,  0.2017, -0.0526],
         [-1.2315,  0.0151,  0.6965,  0.3926,  0.3974,  0.8113],
         [-1.8101, -0.0031,  0.8198, -1.3040,  1.0232,  1.2221]]]), 
tensor([[[-0.5657, -0.3805,  1.2321,  1.5162,  0.6435, -0.1696]],
        [[-0.6071,  1.5682, -0.2740, -0.2582,  0.4433,  0.5099]]]))
这里split_tensor.shape[dim=1] / split_size == 4 / 3 = 1 + 余数;(除不尽),故分割张量数=1+1=2;
按理来说每次在dim=1上的分隔值都是3,但是分割第一次之后只剩1,所以就有多少剩多少了(第二次分割值只能是1了),一般来说,分割值>=1且=split_size or 余数:
yu_shu = split_tensor.shape[dim=1] % split_size  # 这里为了方便才这么写,实际上应是np.array(split_tensor).shape[dim]
if yu_shu = 0:
  split_num = split_tensor.shape[dim=1] / split_size
else:
  split_num = (split_tensor.shape[dim=1] - yu_shu) / split_size + 1

assert split_num >= 1
[注:这里只是提供思路的伪代码]
"""

# 当split_size == split_tensor.shape[dim]时,只需一次分割,故最终返回原张量组成的元组(turple)
split_tensor1_3 = split_tensor.split(split_size=4, dim=1)
print(split_tensor1_3)
"""
(tensor([[[-0.0252,  0.8703,  0.1552, -0.0961,  0.4302,  0.0788],
         [-0.4389, -0.0292,  0.1837,  0.9859, -0.3977,  0.5684],
         [ 0.8940, -1.4846,  0.7611,  0.0483, -2.0573,  2.1025],
         [-0.5657, -0.3805,  1.2321,  1.5162,  0.6435, -0.1696]],
        [[ 0.4291, -0.2098,  0.6542,  1.1694,  0.2017, -0.0526],
         [-1.2315,  0.0151,  0.6965,  0.3926,  0.3974,  0.8113],
         [-1.8101, -0.0031,  0.8198, -1.3040,  1.0232,  1.2221],
         [-0.6071,  1.5682, -0.2740, -0.2582,  0.4433,  0.5099]]]),)
"""

# 例2:split_size是一个列表

# 创建一个形状为(1, 2, 3)的随机张量
split_tensor2 = torch.randn((1, 2, 3))
print(split_tensor2)
"""
tensor([[[-0.7475,  0.5178, -0.0279],
         [-0.2505,  1.4757,  0.9539]]])
"""

# split_size = [2, 1], dim=2
split_tensor2_1 = split_tensor2.split(split_size=[2, 1], dim=2)
print(split_tensor2_1)
"""
(tensor([[[-0.7475,  0.5178],
         [-0.2505,  1.4757]]]), 
 tensor([[[-0.0279],[ 0.9539]]]))
"""
print(split_tensor2_1[0].shape)  # torch.Size([1, 2, 2])
print(split_tensor2_1[1].shape)  # torch.Size([1, 2, 1])

  • 10
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: torch.utils.data.random_splitPyTorch 中的一个数据集划分函数,用于将一个数据集随机划分为多个数据集。它接受两个参数:待划分的数据集和划分比例(比如 [0.8, 0.2] 表示将数据集划分为 80% 和 20% 两部分)。返回值是一个包划分出来的数据集的元组。 ### 回答2: torch.utils.data.random_splitPyTorch提供的一个非常有用的数据集划分函数,可以帮助我们将数据集划分为训练集和验证集。其功能是将一个数据集按照给定的比例随机划分为两个子集。 在机器学习中,通常需要将数据划分为训练集、验证集和测试集,以便对模型进行训练、验证和测试。划分数据集有多种方式,一种最常见的方式是将数据集按照50/50或80/20的比例随机分成训练集和验证集。 PyTorch提供的random_split函数可以帮助我们轻松地完成这个任务。该函数的主要输入是数据集和要划分的比例,它返回两个数据集,一个是训练集,另一个是验证集。这些数据集包输入和目标张量。在划分数据集之前,我们需要将原始数据集转换为PyTorch支持的Dataset类。 下面是torch.utils.data.random_split的使用示例代码: ```python from torch.utils.data import Dataset from torch.utils.data import DataLoader from torch.utils.data import random_split class IrisDataset(Dataset): def __init__(self, X, y): super(IrisDataset,self).__init__() self.X = X self.y = y def __getitem__(self, index): return self.X[index], self.y[index] def __len__(self): return len(self.X) # 创建数据集 dataset = IrisDataset(X, y) # 指定训练集和验证集的比例 train_ratio = 0.8 val_ratio = 0.2 # 计算划分的长度 train_len = int(train_ratio * len(dataset)) val_len = len(dataset) - train_len # 划分数据集 train_set, val_set = random_split(dataset, [train_len, val_len]) # 创建数据加载器 train_loader = DataLoader(dataset=train_set, batch_size=64, shuffle=True) val_loader = DataLoader(dataset=val_set, batch_size=64, shuffle=True) ``` 在上述示例代码中,我们创建了一个虚构的IrisDataset类,它包输入和目标张量。然后,我们创建了一个IrisDataset实例,并将其传递给random_split函数,以便将数据集划分为训练集和验证集。接下来,我们使用DataLoader创建训练集和验证集的迭代器。 总之,torch.utils.data.random_split是一个用于划分数据集的非常方便的函数,可以快速准确地进行训练集和验证集的分割。使用它可以帮助我们更好地管理数据集,并提高机器学习模型的性能。 ### 回答3: torch.utils.data.random_split是一个PyTorch中的数据集划分函数,用于将数据集按照一定比例随机划分为两个子集。该函数的输入参数为原始数据集dataset和划分比例,可以指定划分后子集的大小或比例。返回的结果是两个数据集对象,也可以进一步使用PyTorch提供的数据加载器对数据集进行操作。 在深度学习中,划分训练集、验证集和测试集是非常重要的步骤。可以通过将原始数据集按照一定比例划分为训练集和测试集,为模型评估和模型选择提供数据集的支持。在训练集中再将一部分数据划分为验证集,用于调整模型的超参数和防止模型出现过拟合。因此,使用torch.utils.data.random_split函数来随机划分数据集是非常有用的。 常见的划分方法如下: 1. 将原始数据集按照一定比例划分为训练集和测试集,比如常见的7:3或8:2的比例。 2. 在训练集中再将一部分数据划分为验证集,比如常见的8:1:1或者9:1的比例。 使用torch.utils.data.random_split函数,可以非常方便地实现这种随机划分,具体例子如下: ``` from torch.utils.data import DataLoader, Dataset, random_split class MyDataset(Dataset): def __init__(self, data_list): self.data_list = data_list def __getitem__(self, index): return self.data_list[index] def __len__(self): return len(self.data_list) data = [i for i in range(100)] dataset = MyDataset(data) train_size = int(0.8 * len(dataset)) test_size = len(dataset) - train_size train_dataset, test_dataset = random_split(dataset, [train_size, test_size]) train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False) ``` 上述代码将原始数据集按照8:2的比例随机划分为train_dataset和test_dataset两个数据集对象,其中train_size表示训练集大小,test_size表示测试集大小。最后再将划分后的数据集对象传入DataLoader构建数据加载器进行进一步处理。 实际应用中,可以根据具体任务需求进行相应的数据集划分方法选择和调整。同时也需要注意,随机划分数据集可能会引入一定的随机误差,因此需要多次重复实验,评估模型的平均表现。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值