为GAN分别设置Dropout层与BN层的模式(pytorch)
应用场景是基于GAN进行图像生成时,Dropout层本身就可以被视为一种噪声输入用于增加生成图像的多样化,而BN层在只有单张图片seed输入时会报错。
因此,在完成训练后,实际使用场合中,需要将Dropout层设置为train模式,而BN层设置为eval模式,因此需要对模型整体调整后,再对部分层单独调整。
代码写法参考来源:BN的train,eval模式踩坑记录
实际写法(或者整体设置为测试模型再将Dropout层设置为训练模式):
# 需要设置为训练模式,以通过dropout产生更多的GAN多样性
# 但是BN层必须被设置为测试模式,以保证输出性能
for G_model in G_models:
G_model.train()
for _, module in G_model.named_modules():
# 将所有BN层设置为测试模式
if isinstance(module, nn.BatchNorm1d) or isinstance(module, nn.BatchNorm2d):
module.training = False
return G_models