pytorch 完整的模型训练过程(入门级代码)

pytorch 模型训练过程(入门级代码)

加载数据集

Dataset是一个包装类,用来将数据包装为Dataset类,然后传入DataLoader中,我们再使用DataLoader来更加快捷的对数据进行操作。

Dataset

数据集包括训练集和测试集,在pytorch torchvision中有些常见的数据集,可以直接写代码设置参数下载。例如下载CIFAR10数据集:

train_data=torchvision.datasets.CIFAR10(root="../data",transform=torchvision.transforms.ToTensor(),train=True,download=True)
test_data=torchvision.datasets.CIFAR10(root="../data",transform=torchvision.transforms.ToTensor(),train=False,download=True)

download为True时会自动下载,下载的数据为PIL格式,要转为tensor。

每个数据集下载时的参数设置可能不一样,可以去官方文档具体查看。

如果是自己的数据集,大多数时候需要重写dataset类以方便对数据操作,入门级先不做介绍。

DataLoader

train_dataloader=DataLoader(dataset=train_data,batch_size=64,shuffle=True,num_works=0,drop_last=
  • 6
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值