pytorch划分训练集、验证集与测试集(train_idx、val_idx、test_idx)

假设现在共有10个数据,然后按照5:3:2的比例划分数据。

import torch
import torch.utils.data as D

x = torch.Tensor([10 - x + 100 for x in range(10)])

train_idx, val_idx, test_idx = D.random_split(x, [5, 3, 2])

# random_split函数返回的是一些D.dataset.Subset类(包含两个属性)
tmp = D.dataset.Subset # ctrl+左键见D.dataset.Subset类源码内容
print(test_idx)
print(type(test_idx))

# D.dataset.Subset的第二个属性indices是一个list,保存相应的索引
print(train_idx.indices)
print(val_idx.indices)
print(test_idx.indices)
print(type(train_idx.indices))

# dataset是第一个属性,该例子中数据类型为Tensor,保存的原来未分割的数据
print(train_idx.dataset)
print(type(train_idx.dataset))

# 最终要使用的划分数据如下
print(x[train_idx.indices])
print(x[val_idx.indices])
print(x[test_idx.indices])



输出结果如下所示:

<torch.utils.data.dataset.Subset object at 0x000001C502B9F208>
<class 'torch.utils.data.dataset.Subset'>

[4, 9, 3, 6, 0]
[7, 1, 8]
[5, 2]
<class 'list'>

tensor([110., 109., 108., 107., 106., 105., 104., 103., 102., 101.])
<class 'torch.Tensor'>


tensor([106., 101., 107., 104., 110.])
tensor([103., 109., 102.])
tensor([105., 108.])

End... 

PyTorch 中,可以使用 `Subset` 和 `RandomSampler` 来实现循环划分数据集的训练集验证集。具体步骤如下: 1. 定义数据集 首先,需要定义一个数据集,假设这个数据集的名称为 `my_dataset`。 ```python from torch.utils.data import Dataset class MyDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx] ``` 其中,`data` 是传入的数据列表,`__len__` 方法返回数据集的长度,`__getitem__` 方法根据索引返回对应的数据项。 2. 定义数据集的索引 接下来,需要定义数据集的索引,用于划分训练集验证集。假设数据集有 11 个数据项,需要将其划分为 9 个训练集和 2 个验证集。这可以通过定义一个 `index` 列表来实现。 ```python index = list(range(11)) ``` 3. 循环划分数据集 接下来,需要循环划分数据集。可以使用 `Subset` 和 `RandomSampler` 来实现。 ```python from torch.utils.data import Subset, RandomSampler train_sets = [] val_sets = [] for i in range(11): # 定义训练集验证集的索引 train_index = index[:i] + index[i+1:] val_index = index[i:i+1] # 定义训练集验证集的 Subset train_set = Subset(my_dataset, train_index) val_set = Subset(my_dataset, val_index) # 定义训练集验证集的 Sampler train_sampler = RandomSampler(train_set, replacement=True, num_samples=9) val_sampler = RandomSampler(val_set, replacement=True, num_samples=2) # 添加到训练集验证集列表中 train_sets.append((train_set, train_sampler)) val_sets.append((val_set, val_sampler)) ``` 在循环中,首先定义训练集验证集的索引,然后使用 `Subset` 分别定义训练集验证集。接着,使用 `RandomSampler` 来定义训练集验证集的采样器,这里采用随机采样的方式,每个采样器分别采样 9 个和 2 个数据项。最后,将训练集验证集以及对应的采样器添加到列表中。 这样,就可以得到 11 个训练集验证集以及对应的采样器,用于训练和验证模型。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值