参考:http://spytensor.com/index.php/archives/21/?ksxibs=60ei9
pytorch图片分类任务可以分为七部分:数据加载,模型定义,评测标准定义,训练过程定义,验证过程定义,测试过程定义,参数定义
文件组织结构:
- dataset/
--aug.py
--dataloder.py
- models/
--model.py
- utils/
--progress_bar.py
--utils.py
- checkpoint/
bestmodels/
- logs/
- submits/
- config.py
- main.py
- test.py
- dataset/:包含两个文件aug.py和dataloder.py。aug.py用于数据增强,dataloader.py用于数据加载。
- models/:存放一些自定义的模型,如果不使用pytorch自定义的网络模型,可以在这里添加(要添加__init__.py)
--model.py:定义模型加载
- utils/:定义常用的评测指标
--progress_bar.py:进度条输出工具
--utils.py:定义评测标准,比如accuracy,loss等
- checkpoint/:存放训练保存的模型
bestmodels/:保存在验证集上效果最好的模型
- logs/:存放训练日志
- config.py:参数定义文件,以参数类的形式定义所需提前设定或修改的参数,例如:数据路径,学习率,训练epoch
- main.py:主文件,包含训练,测试,验证等过程