【pytorch】使用torch.utils.data.random_split()划分数据集

写在前面

不用自己写划分数据集的函数,pytorch已经给我们封装好了,那就是torch.utils.data.random_split()

用法详解

torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)

描述

随机将一个数据集分割成给定长度的不重叠的新数据集。可选择固定生成器以获得可复现的结果(效果同设置随机种子)。

参数

  • dataset (Dataset) – 要划分的数据集。
  • lengths (sequence) – 要划分的长度。
  • generator (Generator) – 用于随机排列的生成器。

示例

代码:

import torch
from torch.utils.data import random_split
dataset = range(10)
train_dataset, test_dataset = random_split(
    dataset=dataset,
    lengths=[7, 3],
    generator=torch.Generator().manual_seed(0)
)
print(list(train_dataset))
print(list(test_dataset))

输出:

[4, 1, 7, 5, 3, 9, 0]
[8, 6, 2]

torch.Generator().manual_seed(0)torch.manual_seed(0)的效果相同,我们验证一下。

代码:

import torch
from torch.utils.data import random_split
dataset = range(10)
torch.manual_seed(0)
train_dataset, test_dataset = random_split(
    dataset=dataset,
    lengths=[7, 3]
)
print(list(train_dataset))
print(list(test_dataset))

输出:

[4, 1, 7, 5, 3, 9, 0]
[8, 6, 2]

引用参考

https://pytorch.org/docs/stable/data.html#torch.utils.data.random_split

评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Xavier Jiezou

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值