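Suppose there are 10 samples in total and we want to split them into training, validation and test sets in a 5:3:2 ratio, i.e. 5, 3 and 2 samples.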
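import torch
import torch.utils.data as D

# 10 samples: tensor([110., 109., ..., 101.]); the value at index i is 110 - i
x = torch.Tensor([110 - i for i in range(10)])
# Split the 10 samples into 5/3/2; the assignment is random and changes on every run unless a seed is fixed
train_set, val_set, test_set = D.random_split(x, [5, 3, 2])
# random_split returns a list of D.dataset.Subset objects; each Subset has two attributes, dataset and indices
# (in an IDE, Ctrl+click D.dataset.Subset to read its source)
print(test_set)
print(type(test_set))
# indices (the second attribute) is a list holding the positions of this subset's samples in the original data
print(train_set.indices)
print(val_set.indices)
print(test_set.indices)
print(type(train_set.indices))
# dataset (the first attribute) is the original, unsplit data; in this example it is a Tensor
print(train_set.dataset)
print(type(train_set.dataset))
# Index the original tensor with each subset's indices to get the split data actually used
print(x[train_set.indices])
print(x[val_set.indices])
print(x[test_set.indices])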
The output is shown below. The object address and the random indices will differ from run to run.
<torch.utils.data.dataset.Subset object at 0x000001C502B9F208>
<class 'torch.utils.data.dataset.Subset'>
[4, 9, 3, 6, 0]
[7, 1, 8]
[5, 2]
<class 'list'>
tensor([110., 109., 108., 107., 106., 105., 104., 103., 102., 101.])
<class 'torch.Tensor'>
tensor([106., 101., 107., 104., 110.])
tensor([103., 109., 102.])
tensor([105., 108.])
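The output shows that each subset's indices are disjoint and together cover positions 0 to 9. Conceptually, random_split shuffles all indices once and cuts that permutation into consecutive pieces of the requested lengths. The sketch below reimplements that idea by hand with torch.randperm and Subset to make the mechanism visible. `manual_random_split` is a hypothetical helper written only for illustration and is not part of PyTorch. It handles integer lengths only. The seed 0 is an arbitrary choice that makes the run repeatable. The final check confirms that indexing a Subset returns the same values as indexing x with its indices.

```python
import torch
from torch.utils.data import Subset

def manual_random_split(data, lengths, generator=None):
    # Hypothetical helper (illustration only): shuffle all indices once,
    # then cut the permutation into consecutive slices of the given lengths
    if sum(lengths) != len(data):
        raise ValueError("sum of lengths must equal len(data)")
    perm = torch.randperm(len(data), generator=generator).tolist()
    subsets, offset = [], 0
    for length in lengths:
        subsets.append(Subset(data, perm[offset:offset + length]))
        offset += length
    return subsets

# Same toy data as above: value at index i is 110 - i
x = torch.Tensor([110 - i for i in range(10)])
g = torch.Generator().manual_seed(0)  # arbitrary seed, only for repeatability
train, val, test = manual_random_split(x, [5, 3, 2], generator=g)

# Subset.__getitem__ looks up dataset[indices[i]], so both views agree
assert torch.equal(torch.stack([train[i] for i in range(len(train))]), x[train.indices])
print(train.indices, val.indices, test.indices)
print(x[train.indices], x[val.indices], x[test.indices])
```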
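Because the split is random, the indices printed above change on every run. A minimal sketch of making the split repeatable follows. It assumes a PyTorch version whose random_split accepts the `generator` keyword and passes a freshly seeded torch.Generator each time, so two calls with the same seed produce identical indices. The helper name `seeded_split` and the seed 42 are arbitrary choices for illustration. Newer releases (1.13 and later) also accept fractions such as [0.5, 0.3, 0.2] in place of [5, 3, 2]. The sketch keeps integer lengths so it also runs on older versions.

```python
import torch
import torch.utils.data as D

# Same toy data as above: tensor([110., 109., ..., 101.])
x = torch.Tensor([110 - i for i in range(10)])

def seeded_split(seed):
    # Hypothetical helper: a freshly seeded generator yields the same permutation each call
    g = torch.Generator().manual_seed(seed)
    return D.random_split(x, [5, 3, 2], generator=g)

first = seeded_split(42)
second = seeded_split(42)

for a, b in zip(first, second):
    print(a.indices, b.indices)  # each pair is identical
    assert a.indices == b.indices
```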