经过一段时间的学习,对pytorch有了一定了解,于是总结了一个常见通用的pytoch深度学习代码结构,这个代码结构可以适用于大多数的深度学习任务。自己总结这个一个是将自己学到的东西写下来,以便以后查阅,同时可以给同样学习这个的朋友一个参考借鉴,方面交流,共同进步提高。这种代码存放结构个人比较整洁规范。好的代码都是经过很多次的重构,迭代出来的,我也会不断地学习,更新自己的知识体系,同样也会将最新的学习收获分享出来。
pytorch通用架构
如图所示的,这个框架起手9个文件夹,9个文件夹下分别放着不同用途的文件。
1)参数配置:文件夹下通常放一些json.yml或者parser等网络超参数文件。
2)父类基类:数据载入基类,训练基类,模型基类等基类文件
3)数据加载:存放数据集构建类文件,数据集载入文件
4)模型定义:model 文件夹,创建模型网络主要框架定义损失函数评测指标
5)训练和测试:train,test
6) 主训练器:存放处理训练流程控制代码,多进程, 多线程,checkpoint保存与重载
7)训练结果保存:存储训练过程中的pth文件,checkpoint文件,log日志
8)工具文件夹:utils,tensorboard等
9)bash脚本文件夹:用于存放各种bash脚本,用于控制不同的消融实验,检测不同模型对最终结果的影响。