基础实战——FashionMNIST时装分类
这里基础实战已经给出了pytorch基础实战代码,配置好环境后直接打开,逐步运行即可。
整个流程为
- 读取数据时,由于我们选择使用的是MNIST数据集,由PyTorch,torchvision提供的内置数据集。方式一直接下载速度较慢,选择第二种
- 自行构建Dataset类,提前下载好数据集,读入csv格式的数据
- 定义一个Dataloader类便于后续数据读取
- 设计训练模型,这里只是给了个较容易上手的CNN模型,中间模块化程序较多,需要自己理解。
- 设定损失函数,这里用的是自带的Crossentropy交叉熵损失。
- 设定优化器,选择Adam优化器
- 最后,将各部分代码封装,便于后续改进
- 根据结果?进行训练,调参?
- 至此模型训练完毕。
整体上,通过给定的简单的一个基于Pytorch深度学习训练模型,对深度学习的一般模型训练应用有了最基本的实践与认识,最近在上深度学习课程,这个team-learning可以说,提前多夯实了pytorch框架下,模型训练的一些理论知识,通过简单的动手实践,理解并熟悉了这样一套流程。加油!
更新,直到最后,还没意识到其中问题
- 前面都没啥异常,只是最后一步,训练模型的时候一直卡在那里,未果,发现中间并没有使用GPU,反复安装环境,手动下载,pip安装torch对应的cuda版本的GPU包,.whl文件,至此torch.cuda.is_available()结果为True
- jupter notebook内核显示内核挂掉了?将会立即重启?