(二)GAN生成图像代码解析

1 整体架构

在这里插入图片描述

图1.0 GAN整体架构

2 执行脚本

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os

mb_size = 32
'''X_dim: 输入图像数据尺寸.'''
X_dim = 784
'''noise_dim: 噪音数据尺寸(用于生成"假"图像).'''
noise_dim = 64
'''hidden dim: 隐藏层维度.'''
hidden_dim = 128
'''lr: 学习率.'''
lr = 1e-3
d_steps = 3
LOG_DIR = "./logs"
if not os.path.exists(LOG_DIR):
	os.makedirs(LOG_DIR)
'''MNIST数据'''
mnist = input_data.read_data_sets('../MNIST_data', one_hot=True)

def extract_data():
	'''数据提取测试
	返回:
	images:图像矩阵列表
	labels:图像标签列表
	'''
    images = mnist.train.images
    labels = mnist.train.labels
    return images, labels

def plot(samples):
    '''绘制生成的图像.
    参数: 
    samples: 生成图像的矩阵数据.
    返回:
    fig: 绘图框对象.
    '''
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        '''cmap: 设置图像色阶,Grey_r为黑白,否则生成彩色字体.'''
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
    return fig

def xavier_init(size):
    '''初始化权重和偏置.
    参数:
    size: 指定的数据尺寸.
    返回:
    指定尺寸的随机数据.
    '''
    in_dim = size[0]
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
    return tf.random_normal(shape=size, stddev=xavier_stddev)

def log(x):
    '''log方法计算数据.
    参数 x: 输入数据.
    返回:
    log方法计算后的结果.
    '''
    return tf.log(x + 1e-8)

'''真实图像矩阵数据.'''
X = tf.placeholder(tf.float32, shape=[None, X_dim])
'''噪声矩阵数据.'''
z = tf.placeholder(tf.float32, shape=[None, noise_dim])
'''判别网络参数.'''
D_W1 = tf.Variable(xavier_init([X_dim + noise_dim, hidden_dim]))
D_b1 = tf.Variable(tf.zeros(shape=[hidden_dim]))
D_W2 = tf.Variable(xavier_init([hidden_dim, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))
'''图像预处理网络参数.'''
Q_W1 = tf.Variable(xavier_init([X_dim, hidden_dim]))
Q_b1 = tf.Variable(tf.zeros(shape=[hidden_dim]))
Q_W2 = tf.Variable(xavier_init([hidden_dim, noise_dim]))
Q_b2 = tf.Variable(tf.zeros(shape=[noise_dim]))
'''生成图像网络参数.'''
P_W1 = tf.Variable(xavier_init([noise_dim, hidden_dim]))
P_b1 = tf.Variable(tf.zeros(shape=[hidden_dim]))
P_W2 = tf.Variable(xavier_init([hidden_dim, X_dim]))
P_b2 = tf.Variable(tf.zeros(shape=[X_dim]))
'''变量列表.'''
theta_G = [Q_W1, Q_W2, Q_b1, Q_b2, P_W1, P_W2, P_b1, P_b2]
theta_D = [D_W1, D_W2, D_b1, D_b2]

def sample_z(m, n):
    '''生成噪声数据,为生成图像网络提供输入.
    参数:
    m: 矩阵行数
    n: 矩阵列数
    返回:
    指定维度的随机数据
    '''
    return np.random.uniform(-1., 1., size=[m, n])

def process_real_image(X):
    '''图像预处理.
    参数:
    X: 真实图像矩阵数据.
    返回:
    输出图像尺寸: batch*64
    '''
    h = tf.nn.relu(tf.matmul(X, Q_W1) + Q_b1)
    h = tf.matmul(h, Q_W2) + Q_b2
    '''batch x 64'''
    return h

def generate_image(z):
    '''生成图像计算网络.
    参数:
    z: 噪声输入矩阵
    返回:
    经过sigmoid非线性处理的数据
    '''
    h = tf.nn.relu(tf.matmul(z, P_W1) + P_b1)
    h = tf.matmul(h, P_W2) + P_b2
    '''batch x 784'''
    return tf.nn.sigmoid(h)


def discriminate_image(X, z):
    '''判别网络计算.
    参数 
    X: 真实图像矩阵数据
    z: 图像噪声
    返回:
    sigmoid非线性处理的数据
    '''
    inputs = tf.concat([X, z], axis=1)
    h = tf.nn.relu(tf.matmul(inputs, D_W1) + D_b1)
    '''batch x 1'''
    return tf.nn.sigmoid(tf.matmul(h, D_W2) + D_b2)

'''原图预处理生成 batch*64 图像矩阵.'''
z_hat = process_real_image(X)
'''噪声生成图像,处理生成 batch*784 图像矩阵.'''
X_hat = generate_image(z)
'''判别原图和原图生成的图像为同一张图片的概率.
判别器能力强:D_enc值越大,D_gen越小
'''
D_enc = discriminate_image(X, z_hat)

'''判别生成的图像和噪声为同一张图片的概率.
生成能力强:D_gen越大,极限为:D_enc=D_gen
'''
D_gen = discriminate_image(X_hat, z)

D_loss = -tf.reduce_mean(log(D_enc) + log(1 - D_gen))
tf.summary.scalar("Discriminator", D_loss)
'''判别网络损失:判别能力强,总体D_loss越大越好
D_enc:原图和原图生成图为同一张图片的概率(希望尽可能大)
D_gen:生成图像和噪声为同一张图片的概率(希望尽可能小)
1-D_gen:生成图像和噪声不是同一张图片的概率
'''

G_loss = -tf.reduce_mean(log(D_gen) + log(1 - D_enc))
tf.summary.scalar("Generator", G_loss)
'''生成网络损失:生成能力强(骗过识别网络),总体G_loss越大越好
D_enc:原图和生成图为同一张图片的概率(希望尽可能小)
D_gen:生成图像和噪声为同一张图片的概率(希望尽可能大)
1-D_enc:原图和生成图像不是同一张图片的概率
'''

'''通过上述分析:D_loss和G_loss形成了对抗
都想使自己的概率最大化
'''
'''迭代优化'''
D_solver = (tf.train.AdamOptimizer(learning_rate=lr)
            .minimize(D_loss, var_list=theta_D))
'''Optimize generate network loss.'''
G_solver = (tf.train.AdamOptimizer(learning_rate=lr)
            .minimize(G_loss, var_list=theta_G))
summary_op = tf.summary.merge(tf.get_collection(tf.GraphKeys.SUMMARIES))
def train():
    saver = tf.train.Saver()
    with tf.Session() as sess:
    	summary_write = tf.summary.FileWriter(LOG_DIR, sess.graph)
        sess.run(tf.global_variables_initializer())
        if not os.path.exists('out/'):
            os.makedirs('out/')
        if not os.path.exists('models/'):
            os.makedirs("models/")
        i = 0
        for it in range(10001):
            '''Extract image data from datasets.'''
            X_mb, _ = mnist.train.next_batch(mb_size)
            # print("image data: {}".format(X_mb))
            '''Generate noise.'''
            z_mb = sample_z(mb_size, noise_dim)

            _, D_loss_curr, summary = sess.run(
                [D_solver, D_loss, summary_op], feed_dict={X: X_mb, z: z_mb}
            )

            _, G_loss_curr = sess.run(
                [G_solver, G_loss], feed_dict={X: X_mb, z: z_mb}
            )

            if it % 1000 == 0:
                print('Iter: {}; D_loss: {:.4}; G_loss: {:.4}'
                      .format(it, D_loss_curr, G_loss_curr))

                samples = sess.run(X_hat, feed_dict={z: sample_z(16, noise_dim)})

                '''Save evaluate results.'''
                fig = plot(samples)
                plt.savefig('out/{}.png'
                            .format(str(i).zfill(3)), bbox_inches='tight')
                i += 1
                plt.close(fig)
            saver.save(sess, "./models/gan_test.ckpt")
            summary_write.add_summary(summary, it)

def line_draw():
    print("------------------------")
if __name__ == "__main__":
    train()

3 训练结果

在这里插入图片描述

图3.1 鉴别器损失

在这里插入图片描述

图3.2 生成器损失

鉴别器损失值最终趋于0.5,使生成图像网络和鉴别图像网络平分秋色.

4 总结

(1) 对抗网络共有三个网络,生成图像网络,判别图形网络,图像处理网络;其中,图像处理网络用于生成中间图像加入到判断的原始图像中,用于判断图像真伪,生成图像网络生成图像,判别网络判断生成的图像是否为原始图像;
(2) 对抗的形成:都想争第一,鉴别器判断原始图像,希望输出概率最大 m a x D ( X ) maxD(X) maxD(X),最好为1,生成器生成的图像经过鉴别器判断,希望输出概率最大 m a x D ( G ( z ) ) maxD(G(z)) maxD(G(z)),最好也为1,这样生成图像的网络与鉴别器网络形成了竞争,最后各退一步,最好的状态是各取0.5;
(3) 训练网络:通过对抗原理:
【原始图像提升随机梯度更新鉴别器】
∇ θ d 1 m ∑ i = 1 m [ l o g D ( x i ) + l o g ( 1 − D ( G ( z i ) ) ) ] \nabla_{\theta_{d}}\frac{1}{m}\sum_{i=1}^{m}[logD(x^i)+log(1-D(G(z^i)))] θdm1i=1m[logD(xi)+log(1D(G(zi)))]
【降低随机梯度更新图像生成器】
∇ θ g 1 m ∑ i = 1 m l o g ( 1 − D ( G ( z i ) ) ) \nabla_{\theta_{g}}\frac{1}{m}\sum_{i=1}^{m}log(1-D(G(z^i))) θgm1i=1mlog(1D(G(zi)))
定义损失函数,鉴别器损失:
D l o s s = − t f . r e d u c e ( l o g ( D e n c ) + l o g ( 1 − D g e n ) ) D_{loss} = -tf.reduce(log(D_{enc})+log(1-D_{gen})) Dloss=tf.reduce(log(Denc)+log(1Dgen))
生成器损失:
G l o s s = − t f . r e d u c e ( l o g ( D g e n ) + l o g ( 1 − D e n c ) ) G_{loss} = -tf.reduce(log(D_{gen})+log(1-D_{enc})) Gloss=tf.reduce(log(Dgen)+log(1Denc))
这里的小窍门就是在损失中形成对抗,鉴别器损失中原始图与生成图像鉴别对抗;生成器损失中生成图与原始图对抗,因为单独使用论文中的生成器公式,效果不好,所以改变为生成器也使用对抗,达到较好生成效果。

  • 4
    点赞
  • 39
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

天然玩家

坚持才能做到极致

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值