问题:代码断言错误,模型预测结果的box输出为NaN。
过程:在多卡模型训练时碰到了这个问题,github给出的解决方案五花八门,有降低学习率的,有人num_classes写错了的,但是都不行。
解决方案:关闭混合精度训练,即在训练中让amp=False,为了让自己的batch size大一些,魔改了作者的代码,结果最后是这里出了问题,估计是FP16精度不够溢出了。
额外记录:一个另外的问题时当时想在模型报错的时候打印box的结果,但是终端什么东西都没有,可能是多进程的原因,解决方案为报错时写入log文件而不是print()
额外更新:发现了新的解决方案,如果是半精度训练,则在adamw优化器中改动eps为1e-5,如下。eps是为了防止模型权重出现NaN的一个参数,而半精度训练是没有办法表示到1e-8的,所以改大一些可以防止这个问题。
if args.amp:
eps = 1e-5
else:
eps = 1e-8
optimizer = torch.optim.AdamW(
param_dicts, lr=args.lr, weight_decay=args.weight_decay, eps=eps
)