1 整体架构
2 执行脚本
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
mb_size = 32
'''X_dim: 输入图像数据尺寸.'''
X_dim = 784
'''noise_dim: 噪音数据尺寸(用于生成"假"图像).'''
noise_dim = 64
'''hidden dim: 隐藏层维度.'''
hidden_dim = 128
'''lr: 学习率.'''
lr = 1e-3
d_steps = 3
LOG_DIR = "./logs"
if not os.path.exists(LOG_DIR):
os.makedirs(LOG_DIR)
'''MNIST数据'''
mnist = input_data.read_data_sets('../MNIST_data', one_hot=True)
def extract_data():
'''数据提取测试
返回:
images:图像矩阵列表
labels:图像标签列表
'''
images = mnist.train.images
labels = mnist.train.labels
return images, labels
def plot(samples):
'''绘制生成的图像.
参数:
samples: 生成图像的矩阵数据.
返回:
fig: 绘图框对象.
'''
fig = plt.figure(figsize=(4, 4))
gs = gridspec.GridSpec(4, 4)
gs.update(wspace=0.05, hspace=0.05)
for i, sample in enumerate(samples):
ax = plt.subplot(gs[i])
plt.axis('off')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
'''cmap: 设置图像色阶,Grey_r为黑白,否则生成彩色字体.'''
plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
return fig
def xavier_init(size):
'''初始化权重和偏置.
参数:
size: 指定的数据尺寸.
返回:
指定尺寸的随机数据.
'''
in_dim = size[0]
xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
return tf.random_normal(shape=size, stddev=xavier_stddev)
def log(x):
'''log方法计算数据.
参数 x: 输入数据.
返回:
log方法计算后的结果.
'''
return tf.log(x + 1e-8)
'''真实图像矩阵数据.'''
X = tf.placeholder(tf.float32, shape=[None, X_dim])
'''噪声矩阵数据.'''
z = tf.placeholder(tf.float32, shape=[None, noise_dim])
'''判别网络参数.'''
D_W1 = tf.Variable(xavier_init([X_dim + noise_dim, hidden_dim]))
D_b1 = tf.Variable(tf.zeros(shape=[hidden_dim]))
D_W2 = tf.Variable(xavier_init([hidden_dim, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))
'''图像预处理网络参数.'''
Q_W1 = tf.Variable(xavier_init([X_dim, hidden_dim]))
Q_b1 = tf.Variable(tf.zeros(shape=[hidden_dim]))
Q_W2 = tf.Variable(xavier_init([hidden_dim, noise_dim]))
Q_b2 = tf.Variable(tf.zeros(shape=[noise_dim]))
'''生成图像网络参数.'''
P_W1 = tf.Variable(xavier_init([noise_dim, hidden_dim]))
P_b1 = tf.Variable(tf.zeros(shape=[hidden_dim]))
P_W2 = tf.Variable(xavier_init([hidden_dim, X_dim]))
P_b2 = tf.Variable(tf.zeros(shape=[X_dim]))
'''变量列表.'''
theta_G = [Q_W1, Q_W2, Q_b1, Q_b2, P_W1, P_W2, P_b1, P_b2]
theta_D = [D_W1, D_W2, D_b1, D_b2]
def sample_z(m, n):
'''生成噪声数据,为生成图像网络提供输入.
参数:
m: 矩阵行数
n: 矩阵列数
返回:
指定维度的随机数据
'''
return np.random.uniform(-1., 1., size=[m, n])
def process_real_image(X):
'''图像预处理.
参数:
X: 真实图像矩阵数据.
返回:
输出图像尺寸: batch*64
'''
h = tf.nn.relu(tf.matmul(X, Q_W1) + Q_b1)
h = tf.matmul(h, Q_W2) + Q_b2
'''batch x 64'''
return h
def generate_image(z):
'''生成图像计算网络.
参数:
z: 噪声输入矩阵
返回:
经过sigmoid非线性处理的数据
'''
h = tf.nn.relu(tf.matmul(z, P_W1) + P_b1)
h = tf.matmul(h, P_W2) + P_b2
'''batch x 784'''
return tf.nn.sigmoid(h)
def discriminate_image(X, z):
'''判别网络计算.
参数
X: 真实图像矩阵数据
z: 图像噪声
返回:
sigmoid非线性处理的数据
'''
inputs = tf.concat([X, z], axis=1)
h = tf.nn.relu(tf.matmul(inputs, D_W1) + D_b1)
'''batch x 1'''
return tf.nn.sigmoid(tf.matmul(h, D_W2) + D_b2)
'''原图预处理生成 batch*64 图像矩阵.'''
z_hat = process_real_image(X)
'''噪声生成图像,处理生成 batch*784 图像矩阵.'''
X_hat = generate_image(z)
'''判别原图和原图生成的图像为同一张图片的概率.
判别器能力强:D_enc值越大,D_gen越小
'''
D_enc = discriminate_image(X, z_hat)
'''判别生成的图像和噪声为同一张图片的概率.
生成能力强:D_gen越大,极限为:D_enc=D_gen
'''
D_gen = discriminate_image(X_hat, z)
D_loss = -tf.reduce_mean(log(D_enc) + log(1 - D_gen))
tf.summary.scalar("Discriminator", D_loss)
'''判别网络损失:判别能力强,总体D_loss越大越好
D_enc:原图和原图生成图为同一张图片的概率(希望尽可能大)
D_gen:生成图像和噪声为同一张图片的概率(希望尽可能小)
1-D_gen:生成图像和噪声不是同一张图片的概率
'''
G_loss = -tf.reduce_mean(log(D_gen) + log(1 - D_enc))
tf.summary.scalar("Generator", G_loss)
'''生成网络损失:生成能力强(骗过识别网络),总体G_loss越大越好
D_enc:原图和生成图为同一张图片的概率(希望尽可能小)
D_gen:生成图像和噪声为同一张图片的概率(希望尽可能大)
1-D_enc:原图和生成图像不是同一张图片的概率
'''
'''通过上述分析:D_loss和G_loss形成了对抗
都想使自己的概率最大化
'''
'''迭代优化'''
D_solver = (tf.train.AdamOptimizer(learning_rate=lr)
.minimize(D_loss, var_list=theta_D))
'''Optimize generate network loss.'''
G_solver = (tf.train.AdamOptimizer(learning_rate=lr)
.minimize(G_loss, var_list=theta_G))
summary_op = tf.summary.merge(tf.get_collection(tf.GraphKeys.SUMMARIES))
def train():
saver = tf.train.Saver()
with tf.Session() as sess:
summary_write = tf.summary.FileWriter(LOG_DIR, sess.graph)
sess.run(tf.global_variables_initializer())
if not os.path.exists('out/'):
os.makedirs('out/')
if not os.path.exists('models/'):
os.makedirs("models/")
i = 0
for it in range(10001):
'''Extract image data from datasets.'''
X_mb, _ = mnist.train.next_batch(mb_size)
# print("image data: {}".format(X_mb))
'''Generate noise.'''
z_mb = sample_z(mb_size, noise_dim)
_, D_loss_curr, summary = sess.run(
[D_solver, D_loss, summary_op], feed_dict={X: X_mb, z: z_mb}
)
_, G_loss_curr = sess.run(
[G_solver, G_loss], feed_dict={X: X_mb, z: z_mb}
)
if it % 1000 == 0:
print('Iter: {}; D_loss: {:.4}; G_loss: {:.4}'
.format(it, D_loss_curr, G_loss_curr))
samples = sess.run(X_hat, feed_dict={z: sample_z(16, noise_dim)})
'''Save evaluate results.'''
fig = plot(samples)
plt.savefig('out/{}.png'
.format(str(i).zfill(3)), bbox_inches='tight')
i += 1
plt.close(fig)
saver.save(sess, "./models/gan_test.ckpt")
summary_write.add_summary(summary, it)
def line_draw():
print("------------------------")
if __name__ == "__main__":
train()
3 训练结果
鉴别器损失值最终趋于0.5,使生成图像网络和鉴别图像网络平分秋色.
4 总结
(1) 对抗网络共有三个网络,生成图像网络,判别图形网络,图像处理网络;其中,图像处理网络用于生成中间图像加入到判断的原始图像中,用于判断图像真伪,生成图像网络生成图像,判别网络判断生成的图像是否为原始图像;
(2) 对抗的形成:都想争第一,鉴别器判断原始图像,希望输出概率最大
m
a
x
D
(
X
)
maxD(X)
maxD(X),最好为1,生成器生成的图像经过鉴别器判断,希望输出概率最大
m
a
x
D
(
G
(
z
)
)
maxD(G(z))
maxD(G(z)),最好也为1,这样生成图像的网络与鉴别器网络形成了竞争,最后各退一步,最好的状态是各取0.5;
(3) 训练网络:通过对抗原理:
【原始图像提升随机梯度更新鉴别器】
∇
θ
d
1
m
∑
i
=
1
m
[
l
o
g
D
(
x
i
)
+
l
o
g
(
1
−
D
(
G
(
z
i
)
)
)
]
\nabla_{\theta_{d}}\frac{1}{m}\sum_{i=1}^{m}[logD(x^i)+log(1-D(G(z^i)))]
∇θdm1i=1∑m[logD(xi)+log(1−D(G(zi)))]
【降低随机梯度更新图像生成器】
∇
θ
g
1
m
∑
i
=
1
m
l
o
g
(
1
−
D
(
G
(
z
i
)
)
)
\nabla_{\theta_{g}}\frac{1}{m}\sum_{i=1}^{m}log(1-D(G(z^i)))
∇θgm1i=1∑mlog(1−D(G(zi)))
定义损失函数,鉴别器损失:
D
l
o
s
s
=
−
t
f
.
r
e
d
u
c
e
(
l
o
g
(
D
e
n
c
)
+
l
o
g
(
1
−
D
g
e
n
)
)
D_{loss} = -tf.reduce(log(D_{enc})+log(1-D_{gen}))
Dloss=−tf.reduce(log(Denc)+log(1−Dgen))
生成器损失:
G
l
o
s
s
=
−
t
f
.
r
e
d
u
c
e
(
l
o
g
(
D
g
e
n
)
+
l
o
g
(
1
−
D
e
n
c
)
)
G_{loss} = -tf.reduce(log(D_{gen})+log(1-D_{enc}))
Gloss=−tf.reduce(log(Dgen)+log(1−Denc))
这里的小窍门就是在损失中形成对抗,鉴别器损失中原始图与生成图像鉴别对抗;生成器损失中生成图与原始图对抗,因为单独使用论文中的生成器公式,效果不好,所以改变为生成器也使用对抗,达到较好生成效果。