Generating Handwritten Digits with a GAN

In this post I use a GAN to generate handwritten digits. The principle and the implementation are the same as in my earlier post on generating a parabola with a GAN (click to open link), so let's go straight to the code.

First, I define a visualization function:

import matplotlib.pyplot as plt

def vis_img(idx, samples):
    # Plot a batch of 28x28 grayscale images on an 8x8 grid.
    # samples is a list of batches; idx picks which batch to show
    # (e.g. idx=-1 shows the most recent batch of 64 images).
    fig, axes = plt.subplots(figsize=(7, 7), nrows=8, ncols=8,
                             sharey=True, sharex=True)
    for ax, img in zip(axes.flatten(), samples[idx]):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
    plt.show()
    return fig, axes
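As a quick sanity check, the function can be called the same way the training loop calls it later; here random placeholder data stands in for a batch of generator output:

import numpy as np

# Placeholder data: 64 fake "images" in [-1, 1] with shape (64, 784),
# standing in for a batch of generator samples.
gen_img = np.random.uniform(-1, 1, (64, 784))
vis_img(-1, [gen_img])  # renders the batch as an 8x8 grid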
Next comes the implementation:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
#from utils import vis_img

mnist = input_data.read_data_sets('./data/mnist',one_hot=True)

def generator(inputs, name, reuse=False):
    # inputs: noise vector fed to the generator
    # name:   variable scope name
    # reuse:  whether to reuse the scope's variables
    with tf.variable_scope(name, reuse=reuse):

        fc1 = tf.layers.dense(inputs, units=128, activation=None)
        #bn1 = tf.layers.batch_normalization(fc1)
        #ac1 = tf.nn.relu(bn1)
        ac1 = tf.maximum(0.01 * fc1, fc1)  # leaky ReLU with alpha=0.01

        fc2 = tf.layers.dense(ac1, units=256, activation=None)
        #bn2 = tf.layers.batch_normalization(fc2)
        #ac2 = tf.nn.relu(bn2)
        ac2 = tf.maximum(0.01 * fc2, fc2)

        # Final layer: tanh squashes the output into [-1, 1],
        # matching the rescaled real images below.
        fc3 = tf.layers.dense(ac2, units=784, activation=tf.nn.tanh)
        return fc3
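# Note: tf.maximum(alpha * x, x) is a hand-rolled leaky ReLU. On TF >= 1.4
# (an assumption about the TF version in use) the same activation is
# available directly as tf.nn.leaky_relu(x, alpha=0.01).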
def discriminator(inputs, name, alpha=0.01, reuse=False):

    with tf.variable_scope(name, reuse=reuse):

        fc1 = tf.layers.dense(inputs, 256, activation=None)
        ac1 = tf.maximum(alpha * fc1, fc1)  # leaky ReLU

        fc2 = tf.layers.dense(ac1, 256, activation=None)
        ac2 = tf.maximum(alpha * fc2, fc2)

        # A single logit for the binary real/fake decision.
        logits = tf.layers.dense(ac2, 1, activation=None)
        out = tf.nn.sigmoid(logits)
        return out, logits
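# Note: tf.nn.sigmoid_cross_entropy_with_logits applies the sigmoid itself,
# so the raw logits (not `out`) are what feed the loss terms below; `out` is
# returned only for convenience when inspecting the discriminator's predictions.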
epochs = 100
lr = 0.002
batch_size = 64
gen_size = 100  # dimensionality of the noise vector

with tf.name_scope('gen_inp'):
    gen_inp = tf.placeholder(dtype=tf.float32, shape=[None, gen_size], name='gen_inp')
with tf.name_scope('real_inp'):
    real_inp = tf.placeholder(dtype=tf.float32, shape=[None, 784], name='real_inp')

gen_out = generator(gen_inp,'generator',reuse=False)

real_out,real_logits = discriminator(real_inp,name='discriminator',alpha=0.01,reuse=False)
fake_out,fake_logits = discriminator(gen_out,name='discriminator',alpha=0.01,reuse=True)




with tf.name_scope('metrics'):

    # Generator loss: fool the discriminator into labeling fakes as real.
    loss_g = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake_logits),
                                                                    logits=fake_logits))
    # Discriminator loss on fakes: label generated samples as 0.
    loss_d_g = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake_logits),
                                                                      logits=fake_logits))
    # Discriminator loss on real images, with one-sided label smoothing
    # (targets of 0.99 instead of 1.0).
    loss_d_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real_logits)*0.99,
                                                                         logits=real_logits))
    loss_d = loss_d_g + loss_d_real

    # Each optimizer updates only its own network's variables.
    var_list_g = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
    var_list_d = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')

    g_optimizer = tf.train.AdamOptimizer(lr).minimize(loss_g, var_list=var_list_g)
    d_optimizer = tf.train.AdamOptimizer(lr).minimize(loss_d, var_list=var_list_d)
sum_g = tf.summary.scalar('g_loss', loss_g)
sum_d = tf.summary.scalar('d_loss', loss_d)
mer_g = tf.summary.merge([sum_g])
mer_d = tf.summary.merge([sum_d])
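# The logged curves can be inspected with TensorBoard, e.g.:
#   tensorboard --logdir ./graph/mnist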
with tf.Session() as sess:

    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter('./graph/mnist', sess.graph)
    saver = tf.train.Saver()
    n_batchs = mnist.train.num_examples // batch_size

    for epoch in range(epochs):

        total_loss_d = 0
        total_loss_g = 0

        for ii in range(n_batchs):

            step = epoch * n_batchs + ii  # global step for TensorBoard

            xs_real, ys = mnist.train.next_batch(batch_size)
            xs_real = xs_real * 2 - 1  # rescale pixels from [0, 1] to [-1, 1] to match tanh
            xs_gen = np.random.uniform(-1, 1, [batch_size, gen_size])

            # One discriminator update, then one generator update per batch.
            _, train_loss_d, summ_d = sess.run([d_optimizer, loss_d, mer_d],
                                               feed_dict={gen_inp: xs_gen, real_inp: xs_real})
            writer.add_summary(summ_d, step)

            _, train_loss_g, summ_g = sess.run([g_optimizer, loss_g, mer_g],
                                               feed_dict={gen_inp: xs_gen, real_inp: xs_real})
            writer.add_summary(summ_g, step)

            total_loss_d += train_loss_d
            total_loss_g += train_loss_g

        if epoch % 10 == 0:

            print('epoch {}, loss_g={}'.format(epoch, total_loss_g / n_batchs))
            print('epoch {}, loss_d={}'.format(epoch, total_loss_d / n_batchs))

            # Sample fresh noise and visualize the generated digits.
            xs_gen = np.random.uniform(-1, 1, [batch_size, gen_size])
            gen_img = sess.run(gen_out, feed_dict={gen_inp: xs_gen})
            vis_img(-1, [gen_img])

    writer.close()
    saver.save(sess, "./checkpoints/mnist")
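To make the `metrics` block explicit: writing $D(\cdot)$ for the sigmoid output of the discriminator and $G(z)$ for a generated sample, the three cross-entropy terms amount to the standard non-saturating GAN losses, with one-sided label smoothing (real labels of 0.99 rather than 1):

$$\mathcal{L}_D = -\,\mathbb{E}_{x}\left[0.99\log D(x) + 0.01\log\bigl(1-D(x)\bigr)\right] - \mathbb{E}_{z}\left[\log\bigl(1-D(G(z))\bigr)\right]$$

$$\mathcal{L}_G = -\,\mathbb{E}_{z}\left[\log D(G(z))\right]$$

The smoothing keeps the discriminator from becoming overconfident on real data, which in practice helps training stability.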
Now let's look at the results:

[figure: 8x8 grids of generated MNIST digits]
As you can see, the results are fairly decent.

I also ran the experiments shown in the commented-out code. In the generator, using a batch-normalization layer followed by a plain ReLU gave very poor results: nothing but speckled noise the whole way through.

Then I tried batch normalization followed by Leaky ReLU, and that worked well too.

Finally, I removed the batch-normalization layer altogether and it made little difference; the results were still fine.
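For reference, here is a minimal sketch of that dense -> batch norm -> Leaky ReLU generator variant, assuming `tf.layers.batch_normalization` with an explicit `training` flag (the original commented-out code called it with defaults):

def generator_bn(inputs, name, reuse=False, is_train=True):
    # Generator variant from the experiment above: each dense layer is
    # followed by batch normalization and then a leaky ReLU.
    with tf.variable_scope(name, reuse=reuse):
        fc1 = tf.layers.dense(inputs, units=128, activation=None)
        bn1 = tf.layers.batch_normalization(fc1, training=is_train)
        ac1 = tf.maximum(0.01 * bn1, bn1)

        fc2 = tf.layers.dense(ac1, units=256, activation=None)
        bn2 = tf.layers.batch_normalization(fc2, training=is_train)
        ac2 = tf.maximum(0.01 * bn2, bn2)

        return tf.layers.dense(ac2, units=784, activation=tf.nn.tanh)

One caveat: with `training=True`, batch normalization's moving-average updates are collected in `tf.GraphKeys.UPDATE_OPS`, so the generator's train op should be created under a control dependency on those update ops.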


