前言
导致pytorch的模型训练速度比较慢的原因最有可能的是三个:1. 数据导入环节,操作复杂 2.模型本身很复杂,数据流在模型中传递时过于耗时 3.loss函数计算复杂。
这其中第一个环节往往是最有可能的原因,第二,三个环节其实一回事;pytorch本身的框架针对这两个问题也做了大量的优化,如果不是专业技术过硬,建议在这两个环节上就不要过于纠结了,设计简单易用的模型才是正道。
数据导入环节的优化
数据导入环节,尤其是诸如图像等大张量从内存中的反复读取,以及后续的数据增广操作往往是造成训练速度低的主要原因。针对这个环节的加速其实有一些trick可以试用。
- 常规的做法
常规的做法主要是训练中各种现成的pytorch工具使用以及训练参数的设置,主要有如下的几种方案:
- 采用pytorch 自带的Dataloader而不是自己编写张量的导入类,Dataloader可以方便的设置cpu多线程,很多操作诸如张量的缩放都经过了优化
- 采用更大的batch size,设置cpu或者GPU能够承受的最大batch size,这最大的用处在于节省了后续梯度传递时花的时间
- 可以使用累积梯度,其实就是在cpu或者GPU能承受范围内,多次循环batch 再进行梯度计算
- 保存图,就像下面这样
losses = []
...
losses.append(loss)
print(f'current loss: {torch.mean(losses)'})
5. 使用多个GPU
当然这里推荐一个用于加速运行的插件,叫做pytorch lighting (https://github.com/williamFalcon/pytorch-lightning)
它的使用也是比较简单的
from pytorch_lightning import Trainer
model = LightningModule(…)
trainer = Trainer()
trainer.fit(model)
- 非常规的做法
非常规的做法就得视场合而定了,这些做法并不是对所有的应用场景有效,在不适合的场景里可能造成严重的训练质量下降。
- 半精度或者混合精度训练, 该方法在一些本来对张量精度要求不是很高的领域比较适用,可以显著的提高训练速度,同时显著降低运算显存开销,但是并不是所有领域都适合。关于半精度以及混合精度,可以采用apex library 在英伟达显卡上方便的实现。
- 提前规范化数据,比如大量的图像张量导入时,可以将图像提前缩放成2的n次方类型,这主要是因为大量的框架优化对于这种尺度的图像处理优化效果明显,而对于任意尺寸输入的图像不敢保证;但是提前的规范化有可能造成一些细节的变形
- 使用hdf5格式,提前将数据转成hdf5格式,这种格式对于cpu运算较为友好,同时也是受限比较小的一种方式;但是我在使用中发现,hdf5的解析有赖于自己写的方式,如果技术不过硬这里有可能还是解决不了问题。我这里有一个示例类,可供参考
class AdobePatchDataHDF5(data.Dataset):
def __init__(self, root, cropsize = 256, outputsize = 256):
fgfile = h5py.File(root, 'r')
self.root = root
self.fgfile = fgfile
self.cropsize = cropsize
self.outputsize = outputsize
def __getitem__(self, index):
# read image
fgimg = self.fgfile['img'][index, ...]
# random crop and resize, random flip with cv2
# toTensors
fgimg = fgimg.astype(np.float32) / 255.0
fgimg = torch.from_numpy(fg.transpose((2, 0, 1)))
# norm [0, 1] to [-1, 1]
return fgimg, label
def __len__(self):
return self.fgfile['img'].shape[0]
还有一些是同gpu绑定的方法,比如使用Nvidia DALI,(https://github.com/NVIDIA/DALI),这在预处理阶段可以进行极大的加速,但是目前的稳定版本(截止12.25)好像只能支持有限型号的显卡。
模型训练环节以及loss环节优化
这两部分的优化就比较专业了,需要过硬的本事来做平台上的优化,一般而言很难以取得效果,相反有可能造成较大的问题;比如我曾经手动从头书写DenseNet而不是采用pytorch自身的源代码,结果发现不但速度降低而且还造成显存消耗剧增;当然,当时由于时间问题,这事情就没有深究了。
一般来说,这一块的优化,主要是采用半精度或者混合精度的训练来达成;当然如果硬件允许,其实使用tensorRt来进行训练也是非常不错的选择。
综合
基本上训练缓慢的原因集中在第一点,对于这里的优化方案可供参考的也层出不穷;一般不推荐针对后两种的优化,那对于一般人来说较为复杂,稍不留意可能适得其反。