最近在用chain-model训练时偶尔会出现找不到GPU卡而训练终止。而很烦的是原本的代码,中途训练终止时,重新训练的话,又得重头开始跑,对于跑一遍流程需要好多天的我来说,真的伤不起。
所以通过查看源码,了解了chain-model加载预训练模型的方案,这样即使出现异常,也能从失败的地方继续训练。这个方法只需要在源码中简单的改几个地方即可。详细如下面介绍。
我在查看./steps/nnet3/chain/train.py的代码时,发现其中中含有--trainer.input-model的参数,如果指定的话,就会在初始化的时候,加载指定的参数作为预训练模型。
不过local/chain/run_tdnn.sh脚本中,没有使用到--trainer.input-model参数,所以要调整一下脚本,把这个参数透传出来。添加的代码如下图红框所示:
(1)在执行训练的指令里面,加入下图显示的代码:
(2)然后,代码的开头加入input_model参数,默认值为0.mdl。使得这个参数仍然可以透传出去。
(3)这样在执行local/chain/run_tdnn.sh的脚本里面或者单独执行的时候,就能指定输入模型来,下面是展示单独执行时的示例指令,供参考:
|
这里,--input_model后跟的参数即为我们指定的要预加载的模型。