1、一个完整的深度学习项目,一般包含以下功能
模型定义
数据处理和加载
训练模型
训练过程的可视化
测试
2、程序文件的组织结构:
checkpoints/ 用于保存训练好的模型,使得程序异常退出后仍能重新载入模型
data/ 数据相关操作(数据预处理等)
__init__.py
dataset.py
get_data.sh
models/ 模型,可以有多个模型,一个模型一个py
__init__.py
AlexNet.py
ResNet34.py
utils/ 可能用到的功能函数
__init__.py
visualize.oy
config.py 配置文件,所有的可配置变量都集中在此,并提供默认值
main.py 主文件,训练和测试入口
requirements.txt 程序依赖的第三方库
README.md
3、关于__init__.py
一个目录如果包含了__init__.py文件,那么它就变成了一个包。该文件可以为空,也可以定义包的属性和方法,但它必须存在。
4、数据加载
基本原理:用Dataset封装数据集,再用Dataloader实现数据并加载。
5、主文件
def train(**kwargs):
..
定义网络
定义数据
定义损失函数和优化器
计算