When we talk about generative adversarial networks, the first thing that comes to mind is Goodfellow's seminal paper, Generative Adversarial Nets; that paper is the subject of this post. For the problem of estimating a data distribution, we usually reach for maximum likelihood estimation when the model family is known. But when the model family is unknown, or the data distribution is too complex to write down, how can we approximate the distribution? The adversarial network offers one line of attack.
A generative adversarial network consists of two networks, a generator and a discriminator; in Goodfellow's paper both are multilayer perceptrons. The generator learns a mapping from random noise with a fixed distribution to the target distribution, while the discriminator tries to tell the real data distribution apart from the one the generator produces. Training alternates between the two: the generator pushes its output distribution toward the real data distribution in order to fool the discriminator, and the discriminator sharpens its ability to distinguish the two distributions. At the Nash equilibrium, the discriminator can no longer tell real samples from generated ones.
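For reference, this is the minimax objective from the paper, where $D(x)$ is the discriminator's estimated probability that $x$ came from the data and $G(z)$ maps a noise vector $z$ to a sample:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

The code below implements this adversarial idea in its Wasserstein GAN form (a weight-clipped critic trained with RMSProp instead of a probabilistic discriminator); the differences are pointed out in the comments where they appear.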
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
import tensorflow.contrib.layers as ly
from load_svhn import load_svhn
from tensorflow.examples.tutorials.mnist import input_data
def lrelu(x, leak=0.3, name="lrelu"):
    # leaky ReLU written as f1 * x + f2 * |x|, which equals
    # max(x, leak * x) without a branch
    with tf.variable_scope(name):
        f1 = 0.5 * (1 + leak)
        f2 = 0.5 * (1 - leak)
        return f1 * x + f2 * abs(x)
batch_size = 64
z_dim = 128
learning_rate_ger = 5e-5
learning_rate_dis = 5e-5
device = '/gpu:0'
# img size
s = 32
# number of critic updates per generator update (raised when i < 25 or i % 500 == 0, where i is the iteration step)
Citers = 5
# lower and upper bounds for clipping the critic's parameters
clamp_lower = -0.01
clamp_upper = 0.01
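# note: these defaults (RMSProp at 5e-5, clipping to [-0.01, 0.01], Citers = 5)
# match the hyperparameters recommended in the WGAN paper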
# whether to use the mlp or the dcgan structure
is_mlp = False
# whether to use adam for parameter updates; if False, use tf.train.RMSPropOptimizer
# as recommended in the WGAN paper
is_adam = False
# whether to use SVHN or MNIST; if False, MNIST is used
is_svhn = False
channel = 3 if is_svhn is True else 1
s2, s4, s8, s16 = int(s / 2), int(s / 4), int(s / 8), int(s / 16)
# hidden layer size if mlp is chosen, ignored otherwise
ngf = 64
ndf = 64
# directory to store log, including loss and grad_norm of generator and critic
log_dir = './log_wgan'
ckpt_dir = './ckpt_wgan'
if not os.path.exists(ckpt_dir):
os.makedirs(ckpt_dir)
# max iteration steps; note that one step means Citers updates of the critic and one update of the generator
max_iter_step = 20000
def generator_mlp(z):
    # map noise z through three hidden layers to a flattened image, then
    # reshape to (batch_size, s, s, channel); tanh keeps outputs in [-1, 1]
    train = ly.fully_connected(
        z, 4 * 4 * 512, activation_fn=lrelu, normalizer_fn=ly.batch_norm)
    train = ly.fully_connected(
        train, ngf, activation_fn=lrelu, normalizer_fn=ly.batch_norm)
    train = ly.fully_connected(
        train, ngf, activation_fn=lrelu, normalizer_fn=ly.batch_norm)
    train = ly.fully_connected(
        train, s * s * channel, activation_fn=tf.nn.tanh, normalizer_fn=ly.batch_norm)
    train = tf.reshape(train, tf.stack([batch_size, s, s, channel]))
    return train
def critic_mlp(img, reuse=False):
    # the WGAN critic: same role as a discriminator, but it outputs an
    # unbounded score (no sigmoid), so the final layer has no activation
    with tf.variable_scope('critic') as scope:
        if reuse:
            scope.reuse_variables()
        img = ly.fully_connected(tf.reshape(
            img, [batch_size, -1]), ngf, activation_fn=tf.nn.relu)
        img = ly.fully_connected(img, ngf,
                                 activation_fn=tf.nn.relu)
        img = ly.fully_connected(img, ngf,
                                 activation_fn=tf.nn.relu)
        logit = ly.fully_connected(img, 1, activation_fn=None)
    return logit
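# build_graph below picks generator_conv / critic_conv when is_mlp is False
# (the default), but those two functions are not part of this excerpt.
# What follows is a minimal DCGAN-style sketch of them; the 3x3 kernels,
# stride-2 (de)convolutions, and ngf/ndf-based filter counts are my
# assumptions, not necessarily the author's exact architecture.
def generator_conv(z):
    # project the noise to a 4x4x512 tensor, then upsample 4 -> 8 -> 16 -> 32
    # with stride-2 transposed convolutions; tanh keeps outputs in [-1, 1]
    train = ly.fully_connected(
        z, 4 * 4 * 512, activation_fn=lrelu, normalizer_fn=ly.batch_norm)
    train = tf.reshape(train, (-1, 4, 4, 512))
    train = ly.conv2d_transpose(train, ngf * 4, 3, stride=2,
                                activation_fn=lrelu, normalizer_fn=ly.batch_norm)
    train = ly.conv2d_transpose(train, ngf * 2, 3, stride=2,
                                activation_fn=lrelu, normalizer_fn=ly.batch_norm)
    train = ly.conv2d_transpose(train, channel, 3, stride=2,
                                activation_fn=tf.nn.tanh)
    return train

def critic_conv(img, reuse=False):
    # mirror image of the generator: stride-2 convolutions down to a
    # single unbounded score per image
    with tf.variable_scope('critic') as scope:
        if reuse:
            scope.reuse_variables()
        img = ly.conv2d(img, ndf, 3, stride=2, activation_fn=lrelu)
        img = ly.conv2d(img, ndf * 2, 3, stride=2,
                        activation_fn=lrelu, normalizer_fn=ly.batch_norm)
        img = ly.conv2d(img, ndf * 4, 3, stride=2,
                        activation_fn=lrelu, normalizer_fn=ly.batch_norm)
        logit = ly.fully_connected(
            tf.reshape(img, [batch_size, -1]), 1, activation_fn=None)
    return logit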
def build_graph():
    z = tf.placeholder(tf.float32, shape=(batch_size, z_dim))
    generator = generator_mlp if is_mlp else generator_conv
    critic = critic_mlp if is_mlp else critic_conv
    with tf.variable_scope('generator'):
        train = generator(z)
    real_data = tf.placeholder(
        dtype=tf.float32, shape=(batch_size, 32, 32, channel))
    true_logit = critic(real_data)
    fake_logit = critic(train, reuse=True)
    # WGAN losses: the critic maximizes true_logit - fake_logit (i.e.
    # minimizes its negation), and the generator maximizes fake_logit
    c_loss = tf.reduce_mean(fake_logit - true_logit)
    g_loss = tf.reduce_mean(-fake_logit)
    g_loss_sum = tf.summary.scalar("g_loss", g_loss)
    c_loss_sum = tf.summary.scalar("c_loss", c_loss)
    img_sum = tf.summary.image("img", train, max_outputs=10)
    theta_g = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
    theta_c = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope='critic')
    counter_g = tf.Variable(trainable=False, initial_value=0, dtype=tf.int32)
    # optimize_loss expects a list of summary names, not a bare string
    opt_g = ly.optimize_loss(loss=g_loss, learning_rate=learning_rate_ger,
                             optimizer=tf.train.AdamOptimizer if is_adam else tf.train.RMSPropOptimizer,
                             variables=theta_g, global_step=counter_g,
                             summaries=['gradient_norm'])
    counter_c = tf.Variable(trainable=False, initial_value=0, dtype=tf.int32)
    opt_c = ly.optimize_loss(loss=c_loss, learning_rate=learning_rate_dis,
                             optimizer=tf.train.AdamOptimizer if is_adam else tf.train.RMSPropOptimizer,
                             variables=theta_c, global_step=counter_c,
                             summaries=['gradient_norm'])
    # clip the critic weights to [clamp_lower, clamp_upper] after every
    # critic update to (crudely) enforce the Lipschitz constraint
    clipped_var_c = [tf.assign(var, tf.clip_by_value(var, clamp_lower, clamp_upper))
                     for var in theta_c]
    with tf.control_dependencies([opt_c]):
        opt_c = tf.tuple(clipped_var_c)
    return opt_g, opt_c, z, real_data
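# A note on the losses in build_graph: they are the Wasserstein GAN losses,
# not the original cross-entropy GAN losses. With the critic f kept roughly
# 1-Lipschitz by weight clipping, the pair of updates estimates
#     W(p_data, p_g) ~ max_{||f||_L <= 1}  E_x[f(x)] - E_z[f(G(z))],
# which is why c_loss is a plain difference of logits and the critic's
# final layer has no sigmoid.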
def main():
    if is_svhn:
        dataset = load_svhn()
    else:
        dataset = input_data.read_data_sets('MNIST_data', one_hot=True)
    with tf.device(device):
        opt_g, opt_c, z, real_data = build_graph()
    merged_all = tf.summary.merge_all()
    saver = tf.train.Saver()
    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.8
    def next_feed_dict():
        train_img = dataset.train.next_batch(batch_size)[0]
        # rescale pixels from [0, 1] to [-1, 1] to match the generator's tanh output
        train_img = 2 * train_img - 1
        if not is_svhn:
            # MNIST comes flattened as 28x28; pad with -1 ("black") up to 32x32
            train_img = np.reshape(train_img, (-1, 28, 28))
            npad = ((0, 0), (2, 2), (2, 2))
            train_img = np.pad(train_img, pad_width=npad,
                               mode='constant', constant_values=-1)
            train_img = np.expand_dims(train_img, -1)
        batch_z = np.random.normal(0, 1, [batch_size, z_dim]).astype(np.float32)
        feed_dict = {real_data: train_img, z: batch_z}
        return feed_dict
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        for i in range(max_iter_step):
            # as in the WGAN paper, train the critic much harder early on
            # and every 500 steps
            if i < 25 or i % 500 == 0:
                citers = 100
            else:
                citers = Citers
            for j in range(citers):
                feed_dict = next_feed_dict()
                if i % 100 == 99 and j == 0:
                    run_options = tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()
                    _, merged = sess.run([opt_c, merged_all], feed_dict=feed_dict,
                                         options=run_options, run_metadata=run_metadata)
                    summary_writer.add_summary(merged, i)
                    summary_writer.add_run_metadata(
                        run_metadata, 'critic_metadata {}'.format(i), i)
                else:
                    sess.run(opt_c, feed_dict=feed_dict)
            feed_dict = next_feed_dict()
            if i % 100 == 99:
                # reuses the run_options / run_metadata created in the critic
                # loop above (the j == 0 branch always runs when i % 100 == 99)
                _, merged = sess.run([opt_g, merged_all], feed_dict=feed_dict,
                                     options=run_options, run_metadata=run_metadata)
                summary_writer.add_summary(merged, i)
                summary_writer.add_run_metadata(
                    run_metadata, 'generator_metadata {}'.format(i), i)
            else:
                sess.run(opt_g, feed_dict=feed_dict)
            if i % 1000 == 999:
                saver.save(sess, os.path.join(
                    ckpt_dir, "model.ckpt"), global_step=i)
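# the excerpt defines main() but never calls it; a standard entry point:
if __name__ == '__main__':
    main()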