这篇文章主要是利用GAN生成手写数字(MNIST),原理和实现方法与之前用GAN生成抛物线的那篇文章是一样的(点击打开链接),下面我们直接看代码。
首先,我定义了一个用于可视化的函数:
import matplotlib.pyplot as plt
def vis_img(batch_size, samples):
    """Draw one batch of images on an 8x8 grid and display the figure.

    Args:
        batch_size: Index into ``samples`` that selects the batch to draw.
            Despite its name it is used as an index (for example ``-1``
            picks the last entry).
        samples: Indexable collection of image batches. Every image of the
            selected batch is reshaped to 28x28 before being drawn.

    Returns:
        The ``(fig, axes)`` pair created by ``plt.subplots``.
    """
    fig, axes = plt.subplots(
        nrows=8, ncols=8, sharex=True, sharey=True, figsize=(7, 7)
    )
    chosen_batch = samples[batch_size]
    # zip stops at the shorter sequence, so at most 64 images are drawn.
    for cell, picture in zip(axes.flatten(), chosen_batch):
        # Hide both axes so that only the image itself is visible.
        cell.xaxis.set_visible(False)
        cell.yaxis.set_visible(False)
        cell.imshow(picture.reshape((28, 28)), cmap='Greys_r')
    plt.show()
    return fig, axes
下面就是实现方法:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
#from utils import vis_img
# Load the MNIST data sets from ./data/mnist with one-hot encoded labels.
mnist = input_data.read_data_sets('./data/mnist',one_hot=True)
def generator(inputs, name, reuse=False):
    """Build the generator: three dense layers of 128, 256 and 784 units.

    Args:
        inputs: Input tensor fed to the first dense layer.
        name: Name of the variable scope that holds the layers.
        reuse: Whether to reuse variables already created in that scope.

    Returns:
        The output tensor of the last dense layer (784 units, tanh).
    """
    with tf.variable_scope(name, reuse=reuse):
        # Hidden layers use a leaky ReLU with slope 0.01, written out as
        # max(0.01 * x, x).
        # NOTE: a batch-normalization + ReLU variant was tried here and
        # is intentionally not used.
        hidden1 = tf.layers.dense(inputs, units=128, activation=None)
        hidden1 = tf.maximum(0.01 * hidden1, hidden1)

        hidden2 = tf.layers.dense(hidden1, units=256, activation=None)
        hidden2 = tf.maximum(0.01 * hidden2, hidden2)

        # The output layer applies tanh directly (values in [-1, 1]);
        # no leaky-ReLU step follows it.
        output = tf.layers.dense(hidden2, units=784, activation=tf.nn.tanh)
    return output
def discriminator(inputs, name, alpha=0.01, reuse=False):
    """Build the discriminator: two 256-unit dense layers plus a logits head.

    Args:
        inputs: Input tensor fed to the first dense layer.
        name: Name of the variable scope that holds the layers.
        alpha: Negative-side slope of the leaky ReLU, max(alpha * x, x).
        reuse: Whether to reuse variables already created in that scope.

    Returns:
        A tuple ``(out, logits)``: ``logits`` is the raw output of the last
        dense layer and ``out`` is its element-wise sigmoid.
    """
    with tf.variable_scope(name, reuse=reuse):
        hidden1 = tf.layers.dense(inputs, 256, activation=None)
        hidden1 = tf.maximum(alpha * hidden1, hidden1)

        hidden2 = tf.layers.dense(hidden1, 256, activation=None)
        hidden2 = tf.maximum(alpha * hidden2, hidden2)

        # NOTE(review): the head has 2 units although the losses treat each
        # unit as an independent real/fake sigmoid output - confirm whether
        # a single unit was intended.
        logits = tf.layers.dense(hidden2, 2, activation=None)
        probabilities = tf.nn.sigmoid(logits)
    return probabilities, logits
# ---------------------------------------------------------------------------
# Training script: builds the GAN graph and trains it on MNIST.
# ---------------------------------------------------------------------------

# Hyper-parameters.
epochs = 100
lr = 0.002
batch_size = 64
gen_szie = 100  # length of the generator's input vector (name kept as-is)

# Placeholders for the generator input and for real images (784 values each).
with tf.name_scope('gen_inp'):
    gen_inp = tf.placeholder(dtype=tf.float32, shape=[None, gen_szie], name='gen_inp')
with tf.name_scope('real_inp'):
    real_inp = tf.placeholder(dtype=tf.float32, shape=[None, 784], name='real_inp')

# The discriminator is instantiated twice on the same variables: once for
# real images and once (reuse=True) for the generator's output.
gen_out = generator(gen_inp, 'generator', reuse=False)
real_out, real_logits = discriminator(real_inp, name='discriminator', alpha=0.01, reuse=False)
fake_out, fake_logits = discriminator(gen_out, name='discriminator', alpha=0.01, reuse=True)

with tf.name_scope('metrics'):
    # Generator loss: generated samples should be classified as real (1).
    loss_g = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(fake_logits), logits=fake_logits))
    # Discriminator loss on generated samples: target is fake (0).
    # The labels are shaped from fake_logits (not real_logits) so that the
    # loss stays valid even if the real and generated batches differ in size.
    loss_d_g = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(fake_logits), logits=fake_logits))
    # Discriminator loss on real samples, with the target smoothed to 0.99.
    loss_d_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(real_logits) * 0.99, logits=real_logits))
    loss_d = loss_d_g + loss_d_real

# Each optimizer only updates the variables of its own sub-network.
var_list_g = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
var_list_d = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
g_optimizer = tf.train.AdamOptimizer(lr).minimize(loss_g, var_list=var_list_g)
d_optimizer = tf.train.AdamOptimizer(lr).minimize(loss_d, var_list=var_list_d)

# One scalar summary per network. The discriminator summary previously
# duplicated the generator one, so the discriminator loss was never logged.
sum_g = tf.summary.scalar('g_loss', loss_g)
sum_d = tf.summary.scalar('d_loss', loss_d)
mer_g = tf.summary.merge([sum_g])
mer_d = tf.summary.merge([sum_d])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter('./graph/mnist', sess.graph)
    saver = tf.train.Saver()
    n_batchs = mnist.train.num_examples // batch_size
    for epoch in range(epochs):
        total_loss_d = 0
        total_loss_g = 0
        for ii in range(n_batchs):
            # Global step for the summaries, so the loss curves get an x-axis.
            step = epoch * n_batchs + ii
            xs_real, _ = mnist.train.next_batch(batch_size)
            # Rescale the images with x * 2 - 1, in line with the
            # generator's tanh output.
            xs_real = xs_real * 2 - 1
            xs_gen = np.random.uniform(-1, 1, [batch_size, gen_szie])
            feed = {gen_inp: xs_gen, real_inp: xs_real}
            # One discriminator update followed by one generator update on
            # the same batch.
            _, train_loss_d, summ_d = sess.run([d_optimizer, loss_d, mer_d], feed_dict=feed)
            writer.add_summary(summ_d, step)
            _, train_loss_g, summ_g = sess.run([g_optimizer, loss_g, mer_g], feed_dict=feed)
            writer.add_summary(summ_g, step)
            total_loss_d += train_loss_d
            total_loss_g += train_loss_g
        if epoch % 10 == 0:
            # Every 10 epochs: report the mean losses and show a fresh batch
            # of generated images.
            print('epoch {},loss_g={}'.format(epoch, total_loss_g / n_batchs))
            print('epoch {},loss_d={}'.format(epoch, total_loss_d / n_batchs))
            xs_gen = np.random.uniform(-1, 1, [batch_size, gen_szie])
            gen_img = sess.run(gen_out, feed_dict={gen_inp: xs_gen})
            vis_img(-1, [gen_img])
    writer.close()
    # Make sure the checkpoint directory exists before saving, so a long
    # training run is not lost to a missing folder.
    tf.gfile.MakeDirs('./checkpoints')
    saver.save(sess, "./checkpoints/mnist")
然后我们看一下效果:
可以看出效果还可以。
另外,我还做了一些实验(即代码中被注释掉的部分):在generator里面使用bn层加ReLU层,发现效果一点也不好,生成的图片一直是一堆麻子(噪点)。
然后我再使用bn层加Leaky ReLU,效果也很好。
最后我又把bn层去掉,感觉影响不是很大,效果还可以。