自上一次做完模型训练加速的实验后,陆续又有了一些新的改动,包括:
- Pytorch发布了1.6.0版,官方支持
amp
功能,不再需要外部apex
库; DALI
库也更新了多次,一些API的使用上有些区别;- 研究了一下
DistributedDataParallel
的使用。
因此,重新梳理了训练框架,并将参考代码放到Github上。
如果觉得对你有所启发,请给个star呀。
参考代码在这儿。
1. 训练速度的瓶颈及应对思路
这边主要说的是CV领域,但在其他领域,思路应该也是相通的。
模型训练过程中,影响整体速度的因素主要有以下几点:
- 将数据从磁盘读取到内存的效率;
- 对图片进行解码的效率;
- 对样本进行在线增强的效率;
- 网络前向/反向传播和Loss计算的效率;
针对这几个因素,分别采取如下几种应对思路:
- 加快数据读取可以有几种思路:
- 采取类似TF的tfrecord或者Caffe的lmdb格式,提前将数据打包,比反复加载海量的小文件要快很多,但pytorch没有通用的数据打包方式;
- 在初始化时,提前将所有数据加载到内存中(前提是数据集不能太大,内存能装得下);
- 将数据放在SSD而非HDD,可以极大地提速(前提是你有足够大的SSD