一种基于pyTorch的深度学习项目模板


前言

如同传统程序项目开发一样,在深度学习项目中,有很多程式化的、可封装的代码段。将此类代码段形成可重复使用的模板,将大大开发效率。下面即推荐一种基于pyTorch的深度学习项目模板。
模板地址:https://github.com/victoresque/pytorch-template


一、依赖

  • Python >= 3.5
  • PyTorch >= 0.4
  • tqdm
  • tensorboard >= 1.14

二、特点

  • 清晰的文件夹结构,适用于许多深度学习项目。

  • .json配置文件支持方便参数调整。

  • 可定制的命令行选项,用于更方便的参数调整。

  • 检查点(checkpoint)保存和恢复。

  • 用于更快开发的抽象基类:

    • BaseTrainer处理检查点保存/恢复、训练过程记录等。
    • BaseDataLoader处理batch的生成、数据清洗和训练集/验证集拆分。
    • BaseModel提供基本模型摘要。

三、文件夹结构

pytorch-template/
│
├── train.py - 启动模型训练的py文件
├── test.py - 测试模型的py文件
│
├── config.json - 控制训练的config文件
├── parse_config.py - 处理config文件及cli选项的py文件
│
├── new_project.py - 初始化新项目所需运行的py文件
│
├── base/ - 抽象基类
│   ├── base_data_loader.py
│   ├── base_model.py
│   └── base_trainer.py
│
├── data_loader/ - 数据处理与装入
│   └── data_loaders.py
│
├── data/ - 存放输入数据的默认文件夹
│
├── model/ - 模型,losses和评价指标
│   ├── model.py
│   ├── metric.py
│   └── loss.py
│
├── saved/
│   ├── models/ - 存放训练好的模型文件
│   └── log/ - 默认的tensorboard和log文件存放地址
│
├── trainer/ - trainer类
│   └── trainer.py
│
├── logger/ -tensorboard可视化及logging模块
│   ├── visualization.py
│   ├── logger.py
│   └── logger_config.json
│  
└── utils/ - 其他工具函数
    ├── util.py
    └── ...

四、用法

原repo是模板的MINST示例,可直接使用python train.py -c config.json运行。
当你需要开始一个新项目时,需要首先运行new_project.py。通过python new_project.py ../NewProject创建一个名为“NewProject”的新项目文件夹。该脚本将过滤掉不需要的文件,如cache、git 文件和README.md。

1. config文件

config.json文件详细内容如下所示:

{
  "name": "Mnist_LeNet",        // 项目名称
  "n_gpu": 1,                   // 用于训练的GPU数
  
  "arch": {
    "type": "MnistModel",       // 模型名称
    "args": {

    }                
  },
  "data_loader": {
    "type": "MnistDataLoader",         // 选择DataLoader
    "args":{
      "data_dir": "data/",             // 数据集所在路径
      "batch_size": 64,                // batch size
      "shuffle": true,                 // 在划分训练/验证集前是否打乱数据集
      "validation_split": 0.1          // 验证集的大小
      "num_workers": 2,                // 装入数据集时所开进程数量
    }
  },
  "optimizer": {
    "type": "Adam",
    "args":{
      "lr": 0.001,                     // 学习率
      "weight_decay": 0,               // (可选)权重衰减
      "amsgrad": true
    }
  },
  "loss": "nll_loss",                  // loss
  "metrics": [
    "accuracy", "top_k_acc"            // 评价指标
  ],                         
  "lr_scheduler": {
    "type": "StepLR",                  // 学习率scheduler
    "args":{
      "step_size": 50,          
      "gamma": 0.1
    }
  },
  "trainer": {
    "epochs": 100,                     // epochs
    "save_dir": "saved/",              // checkpoints文件会存储在save_dir/models/name
    "save_freq": 1,                    //
    "verbosity": 2,                    // 0: quiet, 1: per epoch, 2: full
  
    "monitor": "min val_loss"          // mode and metric for model performance monitoring. set 'off' to disable.
    "early_stop": 10	                 // number of epochs to wait before early stop. set 0 to disable.
  
    "tensorboard": true,               // 是否使用tensorboard
  }
}

2. 使用 config.json

修改好config.json文件,然后运行

python train.py --config config.json

3. 从检查点(checkpoint)继续

我们可以从上一检查点继续训练:

python train.py --resume path/to/checkpoint

4. 使用多个GPU

可以通过将配置文件的n_gpu参数设置地更大来启用多 GPU 训练。如果配置为使用比可用数量更少的 gpu,默认情况下将使用前 n 个设备,但我们仍可以通过--device来指定GPU。

python train.py --device 2,3 -c config.json

五、自定义

1. 自定义 CLI 选项

更改配置文件的值是调整超参数的一种干净、安全且简单的方法。但是,如果某些值需要过于频繁或快速地更改,有时最好使用命令行选项。
该模板默认使用存储在 json 文件中的配置,但你仍可以通过命令行选项的方式更改其中的一部分。

# simple class-like object having 3 attributes, `flags`, `type`, `target`.
CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
options = [
    CustomArgs(['--lr', '--learning_rate'], type=float, target=('optimizer', 'args', 'lr')),
    CustomArgs(['--bs', '--batch_size'], type=int, target=('data_loader', 'args', 'batch_size'))
    # options added here can be modified by command line flags.
]

2. DataLoader

我们可以自由编写自己的DataLoader。

  • 继承BaseDataLoader
    BaseDataLoadertorch.utils.data.DataLoader的子类,您可以使用其中任何一个。
    BaseDataLoader主要处理:
    • 生成下一个batch
    • 对数据集做shuffle
    • 通过调用BaseDataLoader.split_validation()做训练/验证集划分
  • DataLoader的使用
    BaseDataLoader是一个迭代器,用于迭代每一个batch:
for batch_idx, (x_batch, y_batch) in data_loader:
    pass

3. Trainer

我们可以自由编写自己的trainer。

  • 继承BaseTrainer
    BaseTrainer主要处理:

    • 训练过程记录

    • 检查点保存、恢复

    • 可重新配置的性能监控,用于保存当前的最佳模型,并提前停止训练。

      • 如果 config文件中,monitor设置为max val_accuracy,这意味着每个epoch结束后都会保存一个最佳模型model_best.pth
      • 如果 config文件中,early_stop被设置为true,当模型性能在给定数量的 epoch 内没有提高时,训练将自动终止。
  • 实现抽象方法
    你必须实现_train_epoch()。如果需要验证功能,则需要进一步实现_valid_epoch()。上述两种方法都在trainer/trainer.py里。

4. Model

我们可以自由编写自己的Model。

  • 继承BaseModel
    BaseModel主要处理:

    • 继承自torch.nn.Module
    • __str__:修改print函数以打印可训练参数的数量。

5. Loss

自定义损失函数可以在“model/loss.py”中实现。通过将config文件中“loss”更改为相应的名称来使用它们。

6. 评价指标(Metrics)

自定义的metrics可在“model/metrics.py”,方法同上。

  • 4
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Morty徐同学

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值