问题描述
前一段升级了Pytorch(从0.4.0升级到了1.1.0),这几天用多块卡跑一个判别器用了WGAN-GP的网络。
当使用单张GPU时,一切正常;
但当使用多块GPU时,D Loss会爆炸!爆炸原因是其中的Gradient Penalty分量爆炸!
问题原因:
参考这篇震惊的文章:震惊!!!PyTorch实现的WGAN-GP竟会爆炸
发现是pytorch版本的问题。
解决方案:
把pytorch更新到最新版本(1.6.0)即可。
对于cuda version=10.2, Ubuntu系统,anaconda下的更新方法:
conda install pytorch torchvision cudatoolkit=10.2 -c pytorch
结果
升级到1.6.0之前:
红线代表判别器Loss,其出现爆炸的原因是其中的Gradient Penalty分量;
升级到1.6.0之后:
Loss变得正常。