先抱怨一句,GAN训练真不容易…
一开始GAN总是训练效果很差,使用的是CNN去训练generator和discriminator,但是出来的总是“晶体图”,找到了一些技巧,然后自己实践了一下,也遇到了一些蠢哭的事情。
我换了一个网络,即调整了原来CNN的参数,或者多加入了一些层,结果出来的效果非常可观,看来首先要找到一个对的网络。
我尝试了用TensorFlow去实现GAN,一开始是用的Keras。实现第一遍的时候,其实还不错,使用的是全连接网络,没有使用CNN。
第二次想尝试WGAN,就跟着一位博主的文章开始了练习,但是敲完之后,结果非常糟糕,后来检查发现没有控制trainable_variables,在训练generator的时候,把discriminator的参数也更新了。
后来发现生成的图片都非常暗,感觉哪里有问题,但是有没发现,后来在看别人代码的时候才发现,我将图片的RGB值由(0,225)缩放到 (-1,1)的时候,我缩放错了,直接除以了255,结果缩放成了(0,1)。
下面是我觉得很不错的技巧:
-
输入的图片经过处理,将0-255的值变为-1到1的值。
images = (images/255.0)*2 - 1
-
在generator输出层使用
tanh
激励函数,使得输出范围在 (-1,1) -
保存生成的图片时,将矩阵值缩放到[0,1]之间
gen_image = (gen_image+1) / 2
-
使用
leaky_relu
激励函数,使得负值可以有一定的比重 -
使用BatchNormalization,使分布更均匀,最后一层不要使用。
-
在训练generator和discriminator的时候,一定要保证另外一个的参数是不变的,不能同时更新两个网络的参数。
-
如果训练很容易卡住,可以考虑使用WGAN
-
可以选择使用RMSprop optimizer
更多有用的技巧:
GAN训练技巧
Tips on train GAN