pytorch之数据集划分

本文介绍了在深度学习中,如何通过PyTorch的random_split函数对数据集进行训练集、测试集和验证集的划分,以防止过拟合并提高模型的泛化能力。作者提供了实际的代码示例,展示了如何使用TensorDataset和设置数据划分比例。
摘要由CSDN通过智能技术生成

人工智能三个核心要素算法,算力和数据

1.数据集的划分

        在网络训练时,需要对数据集进行划分。一般划分为3部分:训练集、测试集、验证集。对数据集划分的目的是为了防止训练的模型过拟合,增加模型的准确性和泛化能力。

  1. 训练集:用来调试神经网络中的参数,通过计算损失函数进行网络参数的更新。
  2. 测试集:测试模型的训练效果,每个epoch后通过查看训练集和验证集的损失值变化关系,查看模型训练的效果。如果出现过拟合或不收敛等情况可以停止训练,通过排查数据、调整模型结构及修改超参数等方式进行优化。
  3. 验证集: 用来评估模型的准确性和泛化能力。

2.Pytorch代码实现

        通常情况下上述数据集需要保持互斥。一般情况下训练集和测试集是在同一批次的数据,验证集是不同批次或有差异场景下的数据。(本次实验则模拟使用了同一批次数据。)

        在pytorch中通过torch.utils.data.random_split()函数来实现数据集的划分,

        Random_split函数可以实现数据集的随机划分。

def random_split(dataset: Dataset[T], lengths: Sequence[int],
                 generator: Optional[Generator] = default_generator) -> List[Subset[T]]:
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.
    Optionally fix the generator for reproducible results, e.g.:

    >>> random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
        generator (Generator): Generator used for the random permutation.
    """
    # Cannot verify that dataset is Sized
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths), generator=generator).tolist()
    return [Subset(dataset, indices[offset - length : offset]) for offset, length in zip(_accumulate(lengths), lengths)]

        参数:

              dataset (Dataset): 划分的数据集

              lengths (sequence): 被划分数据集的长度

             Generator(Optional[Generator]):随机生成器,可选择固定生成器以获得可复现的结果

             返回类型为 list

        示例代码如下:

import torch
from torch.utils.data import DataLoader, random_split

#生成数据集
image = torch.randn(500, 4) # 生成500个 每个元素均为标准正态分布的4列随机张量,即500行4列的张量
label = torch.randint(low=0, high=4,size=(500,)).float() #生成500个为值为0-3的label
dataset = torch.utils.data.TensorDataset(image, label) #将image和label组装成数据集

#设置训练集/测试集/验证集的比例,数据量较少时,一般设置为8:1:1
train_ratio = 0.8  #训练集所占比例
test_ratio = 0.1   #测试集所占比例
val_ration = 0.1   #验证集所占比例

#计算训练集&测试集的size
train_data_size = int(len(dataset) * train_ratio)
test_data_size = int(len(dataset) * test_ratio)
val_data_size = len(dataset) - train_data_size - test_data_size

#使用random_split()函数,随机划分数据集
#dataset (Dataset): 划分的数据集
#lengths (sequence): 被划分数据集的长度
#可选择固定生成器以获得可复现的结果
train_dataset, test_dataset, val_dataset = random_split(dataset=dataset,lengths=[train_data_size, test_data_size, val_data_size])

train_dataloder =  DataLoader(dataset=train_dataset, batch_size=8, shuffle=True, drop_last=False)
test_datasetder = DataLoader(dataset=test_dataset, batch_size=8, shuffle=True, drop_last=False)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=8, shuffle=True, drop_last=False)

print(f"总数据大小为:{len(dataset)}")
print("训练集大小为:{0}".format(len(train_dataset)))
print("测试集大小为:{0}".format(len(test_dataset)))
print("验证集大小为:{0}".format(len(val_dataset)))

运行结果如下:

补充:torch.utils.data.TensorDataset() 用来对 tensor 进行组包该函数要求参数中的每个tensor第一维度必须相等。

3.总结

        使用torch.utils.data.random_split()函数可以方便快捷的对数据集进行随机划分,生成训练集测试集及验证集。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值
>