WGAN-GP是针对WGAN存在的问题提出来的,WGAN在真实的实验过程中依旧存在着训练困难、收敛速度慢的问题,相比较传统GAN在实验上提升不是很明显。WGAN-GP在文章中指出了WGAN存在问题的原因,那就是WGAN在处理Lipschitz限制条件时直接采用了weight clipping。相关讲解请参考
WGAN-GP的介绍
同往期一样,依然以生成cifar数据集中马的彩色图片为例,关于cifar数据集的读取和生成器模型的验证请参考第6期:
用DCGAN生成马的彩色图片
下面给出WGAN-GP框架
"""
-------------------------------------------------------生死看淡,不服就GAN-------------------------------------------------------------------------
PROJECT: CIFAR10_WGAN-GP
Author: Ephemeroptera
Date:2019-3-19
QQ:605686962
"""
"""
WGAN-GP说明:相比较WGAN,WGAN-GP提出以下改进:
(1)用对判别器梯度惩罚取代WGAN的判决器权值区间截断
(2)判别器取消BN操作
(3)优化器使用ADAM
"""
import numpy as np
import tensorflow as tf
import pickle
import TFRecordTools
import time
# ---- Hyper-parameters ------------------------------------------------------
real_shape = [-1,32,32,3]   # NHWC shape of real CIFAR images (batch dim inferred)
data_total = 5000           # presumably the number of training samples — unused below
batch_size = 64             # samples per training batch
noise_size = 128            # dimensionality of the generator's input noise
max_iters = 50000           # total generator update steps
learning_rate = 5e-5        # Adam learning rate for both networks
beta1 = 0.5                 # Adam beta1 (WGAN-GP paper recommendation)
beta2 = 0.9                 # Adam beta2 (WGAN-GP paper recommendation)
CRITIC_NUM = 5              # critic (discriminator) updates per generator update
lam = 10                    # gradient-penalty coefficient (lambda in the paper)
def Generator_DC_32x32(z, channel, is_train=True):
    """Build a DCGAN-style generator mapping noise to a 32x32 image.

    :param z: noise tensor of shape [batch, noise_dim]
    :param channel: number of channels in the generated image
    :param is_train: True while training; controls batch-norm mode and
        variable reuse (the test-time graph reuses the training variables)
    :return: (logits, outputs) where outputs = tanh(logits), i.e. in [-1, 1]
    """
    w_init = tf.random_normal_initializer(0, 0.02)
    with tf.variable_scope("generator", reuse=(not is_train)):
        # Project the noise and reshape it into a 4x4x512 feature map.
        net = tf.layers.dense(z, 4 * 4 * 512)
        net = tf.reshape(net, [-1, 4, 4, 512])
        net = tf.layers.batch_normalization(net, training=is_train)
        net = tf.nn.relu(net)
        # Three stride-2 transposed convolutions: 4x4 -> 8x8 -> 16x16 -> 32x32.
        for filters in (256, 128, 64):
            net = tf.layers.conv2d_transpose(net, filters, 3, strides=2, padding='same',
                                             kernel_initializer=w_init,
                                             bias_initializer=w_init)
            net = tf.layers.batch_normalization(net, training=is_train)
            net = tf.nn.relu(net)
        # Final stride-1 transposed conv down to the requested channel count;
        # no batch norm here so tanh sees the raw logits.
        logits = tf.layers.conv2d_transpose(net, channel, 3, strides=1, padding='same',
                                            kernel_initializer=w_init,
                                            bias_initializer=w_init)
        outputs = tf.tanh(logits)
    return logits, outputs
def Discriminator_DC_32x32(inputs_img, reuse=False, GAN = False,GP= False,alpha=0.2):
    """Build a DCGAN-style discriminator/critic for 32x32 images.

    @param inputs_img: input image tensor, [batch, 32, 32, channels]
    @param reuse: reuse the variables created by a previous call
    @param GAN: when True, emit no sigmoid output (WGAN critic mode)
    @param GP: when True (WGAN-GP), skip batch normalization entirely
    @param alpha: Leaky ReLU negative-slope coefficient
    @return: (logits, outputs); outputs is None in WGAN mode
    """
    with tf.variable_scope("discriminator", reuse=reuse):
        net = inputs_img
        # Three stride-2 convolutions: 32x32 -> 16x16 -> 8x8 -> 4x4.
        for filters in (128, 256, 512):
            net = tf.layers.conv2d(net, filters, 3, strides=2, padding='same')
            if GP is False:
                # Batch norm only outside WGAN-GP mode — the gradient
                # penalty is computed per sample and BN would couple them.
                net = tf.layers.batch_normalization(net, training=True)
            net = tf.nn.leaky_relu(net, alpha=alpha)
        flat = tf.reshape(net, [-1, 4 * 4 * 512])
        logits = tf.layers.dense(flat, 1)
        "WGAN:去除sigmoid"
        outputs = None if GAN else tf.sigmoid(logits)
        return logits, outputs
# ---------------------------------------------------------------------------
# Graph construction: placeholders, generator/critic, WGAN-GP losses.
# ---------------------------------------------------------------------------
# Real images (expected in [-1, 1], see batch_preprocess) and latent noise.
inputs_real = tf.placeholder(tf.float32, [None, real_shape[1], real_shape[2], real_shape[3]], name='inputs_real')
inputs_noise = tf.placeholder(tf.float32, [None, noise_size], name='inputs_noise')
# Training generator (creates the variables) and a reused test-time copy.
_,g_outputs = Generator_DC_32x32(inputs_noise, real_shape[3], is_train=True)
_,g_test = Generator_DC_32x32(inputs_noise, real_shape[3], is_train=False)
'WGAN-GP:判别器废除批归一化'
# Critic scores: GP=True disables BN, GAN=True disables the sigmoid output.
d_logits_real, _ = Discriminator_DC_32x32(inputs_real,GAN=True,GP=True)
d_logits_fake, _ = Discriminator_DC_32x32(g_outputs,GAN=True,GP=True,reuse=True)
"WGAN:损失函数去log,采用Wasserstein距离形式"
# Wasserstein losses: generator pushes fake scores up; critic maximizes the
# real-minus-fake score gap (minimizes its negation).
g_loss = tf.reduce_mean(-d_logits_fake)
d_loss = tf.reduce_mean(d_logits_fake - d_logits_real)
'WGAN-GP:加入判别器梯度惩罚项'
# Gradient penalty: evaluate the critic at random interpolations between real
# and generated samples and penalize deviation of the gradient L2 norm from 1.
alpha_dist = tf.contrib.distributions.Uniform(low=0., high=1.)
alpha = alpha_dist.sample((batch_size, 1, 1, 1))
interpolated = inputs_real + alpha*(g_outputs-inputs_real)
inte_logit,_ = Discriminator_DC_32x32(interpolated, GAN=True,GP=True,reuse=True)
gradients = tf.gradients(inte_logit, [interpolated,])[0]
grad_l2 = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1,2,3]))
gradient_penalty = tf.reduce_mean((grad_l2-1)**2)
d_loss+=gradient_penalty*lam
# Split trainable variables between the two sub-networks by scope prefix.
train_vars = tf.trainable_variables()
g_vars = [var for var in train_vars if var.name.startswith("generator")]
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
# Run batch-norm moving-average updates (UPDATE_OPS) before each optimizer
# step; WGAN-GP uses Adam for both networks.
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    g_train_opt = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1, beta2=beta2).minimize(g_loss, var_list=g_vars)
    d_train_opt = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1, beta2=beta2).minimize(d_loss, var_list=d_vars)
# ---------------------------------------------------------------------------
# Input pipeline via the project helper TFRecordTools — presumably it reads
# serialized CIFAR records and assembles shuffled batches; verify against
# that module (its internals are not visible here).
# ---------------------------------------------------------------------------
[data,label] = TFRecordTools.ReadFromTFRecord(sameName= r'.\TFR\class7-*',isShuffle= False,datatype= tf.float64,
                                              labeltype= tf.int32,isMultithreading= True)
[data_batch,label_batch] = TFRecordTools.DataBatch(data,label,dataSize= 32*32*3,labelSize= 1,
                                                   isShuffle= True,batchSize= 64)
GenLog = []   # snapshots of generated samples collected during training
losses = []   # [d_loss, g_loss, step] triples for later plotting
# Checkpoint only the generator's variables — that is all inference needs.
saver = tf.train.Saver(var_list=[var for var in tf.trainable_variables()
                                 if var.name.startswith("generator")],max_to_keep=5)
def batch_preprocess(data_batch):
    """Pull one batch from the input queue and format it for the network.

    Reshapes the flat records to NCHW, transposes to NHWC, and rescales
    values from (presumably) [0, 1] to [-1, 1] to match the generator's
    tanh output range. Relies on the module-level `sess` bound by the
    training block below.
    """
    raw = sess.run(data_batch)
    images = np.transpose(np.reshape(raw, [-1, 3, 32, 32]), (0, 2, 3, 1))
    return images * 2 - 1
def GEN_DIR():
    """Ensure the output directories exist under the current directory.

    Creates 'ckpt' (checkpoints) and 'trainLog' (loss/sample logs) when
    missing, printing a notice for each one created.
    """
    import os
    for d in ('ckpt', 'trainLog'):
        if not os.path.isdir(d):
            # Bug fix: the original printed '文件夹ckpt未创建' for BOTH
            # directories; the message now names the directory it creates.
            print('文件夹%s未创建,现在在当前目录下创建..' % d)
            os.mkdir(d)
# ---------------------------------------------------------------------------
# Training: alternate critic and generator updates, log losses and samples,
# checkpoint the generator periodically, then persist the training logs.
# ---------------------------------------------------------------------------
with tf.Session() as sess:
    GEN_DIR()
    # Local variables are needed by the input-queue machinery.
    init = (tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init)
    # Start the TFRecord queue-runner threads.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    time_start = time.time()
    for steps in range(max_iters):
        steps += 1
        # WGAN schedule: train the critic much harder during warm-up and
        # periodically; otherwise CRITIC_NUM times per generator step.
        if steps < 25 or steps % 500 == 0:
            critic_num = 100
        else:
            critic_num = CRITIC_NUM
        # BUG FIX: the original looped over range(CRITIC_NUM), silently
        # ignoring the critic_num schedule computed just above.
        for i in range(critic_num):
            batch_images = batch_preprocess(data_batch)
            batch_noise = np.random.normal(size=(batch_size, noise_size))
            _ = sess.run(d_train_opt, feed_dict={inputs_real: batch_images,
                                                 inputs_noise: batch_noise})
        # One generator update per outer step.
        batch_images = batch_preprocess(data_batch)
        batch_noise = np.random.normal(size=(batch_size, noise_size))
        _ = sess.run(g_train_opt, feed_dict={inputs_real: batch_images,
                                             inputs_noise: batch_noise})
        if steps % 5 == 1:
            # Record current losses and a small gallery of test-time samples.
            train_loss_d = d_loss.eval({inputs_real: batch_images,
                                        inputs_noise: batch_noise})
            train_loss_g = g_loss.eval({inputs_real: batch_images,
                                        inputs_noise: batch_noise})
            losses.append([train_loss_d, train_loss_g, steps])
            batch_noise = np.random.normal(size=(batch_size, noise_size))
            gen_samples = sess.run(g_test, feed_dict={inputs_noise: batch_noise})
            # Map tanh output from [-1, 1] back to [0, 1] for display.
            genLog = (gen_samples[0:11] + 1) / 2
            GenLog.append(genLog)
            print('step {}...'.format(steps),
                  "Discriminator Loss: {:.4f}...".format(train_loss_d),
                  "Generator Loss: {:.4f}...".format(train_loss_g))
        if steps % 300 == 0:
            saver.save(sess, './ckpt/generator.ckpt', global_step=steps)
    # Shut the input queue down cleanly before the session closes.
    coord.request_stop()
    coord.join(threads)
    time_end = time.time()
    print('迭代结束,耗时:%.2f秒'%(time_end-time_start))
    # Persist the loss curve and the generation log for offline inspection.
    with open('./trainLog/loss_variation.loss', 'wb') as l:
        losses = np.array(losses)
        pickle.dump(losses,l)
        print('保存loss信息..')
    with open('./trainLog/GenLog.log', 'wb') as g:
        pickle.dump(GenLog, g)
        print('保存GenLog信息..')
结果展示
最后一次生成样本

训练期间生成日志

损失函数

验证生成器
