实例描述
使用变分自编码模型进行模拟MNIST数据的生成
1.引入库,定义占位符
编码器为两个全连接层:第一个全连接成有784个维度的输入变化为256个维度的输出;第二个全连接层并列连接了两个输出网络(mean与lg_var),每个网络都输出了两个维度的输出。然后将两个输出通过一个公式的计算,输入到以一个2节点为开始的解码部分,接着后面为两个全连接层的解码器。第一层由两维度的输入到256维度的输出,第二层由256个维度的输入到784个维度的输出。
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from scipy.stats import norm
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/data/")#, one_hot=True)
n_input = 784
n_hidden_1 = 256
n_hidden_2 = 2
#定义占位符
#输入
x = tf.placeholder(tf.float32, [None, n_input])
zinput = tf.placeholder(tf.float32, [None, n_hidden_2])
zinput是一个占位符,后面通过它输入分布数据,用来生成模拟样本数据。
2.定义学习参数
定义的参数有所变化,mean_w1与mean_b1是生成mean的权重,log_sigma_w1与log_sigma_b1是生成log_sigma的权重。
#定义学习参数
weights = {
'w1': tf.Variable(tf.truncated_normal([n_input, n_hidden_1],
stddev=0.001)),
'b1': tf.Variable(tf.zeros([n_hidden_1])),
'mean_w1': tf.Variable(tf.truncated_normal([n_hidden_1, n_hidden_2],
stddev=0.001)),
'log_sigma_w1': tf.Variable(tf.truncated_normal([n_hidden_1, n_hidden_2],
stddev=0.001)),
'w2': tf.Variable(tf.truncated_normal([n_hidden_2, n_hidden_1],
stddev=0.001)),
'b2': tf.Variable(tf.zeros([n_hidden_1])),
'w3': tf.Variable(tf.truncated_normal([n_hidden_1, n_input],
stddev=0.001)),
'b3': tf.Variable(tf.zeros([n_input])),
'mean_b1': tf.Variable(tf.zeros([n_hidden_2])),
'log_sigma_b1': tf.Variable(tf.zeros([n_hidden_2]))
}
注意w值采用比较小的方差,设置的非常小,是由于算KL离散度时计算的是与标准高斯分布的距离,如果网络初始生成的模型均值方差很大,那么与标准高斯粉嫩不的距离就会非常大,导致模型训练不出来。
3.定义网络结构
网络节点可以按照以下代码进行定义,但在变分解码器中为训练的中间节点赋予了特殊的意义,也就是中间的均值和方差节点,并将他们所代表的数据集向着标准高斯分布数据靠近(也是就是原始数据是样本,标准高斯分布数据是标签)。然后可以使用KL离散度公式,来计算该均值和方差节点代表的数据集合与标准的高斯分布集合间的距离,将这个距离当成误差,让它最小化从而优化网络参数。
这里的方差节点不是真正意义的方差,是取log之后的值,会后tf.exp还原的操作取方差的值,再通过tf.sqrt将其开方得到标准差。用符合标准正太分布的一个数乘以标准差加上均值,就使这个数成为符合(z_mean,sigma)数据分布集合里的一个点(z_mean是指网络生成均值,sigma是指网络生成的z_log_sigma_sq变换后的值)
到此,完成了编码阶段,将原始数据编码输出3个值:
- 一个是表示该数据分布的均值
- 一个是表示数据分布的方差
- 还有一个是得到了该数据分布的一个实际的点z
上面说明的就是z求得的方法
这里变换的知识点是:
假如一个符合高斯分布的数据集的均值、标准差为(m,sigma),这里样本中的值为(z_mean,z_log_sigma_sq),其中的某个点,可以通过一个符合标准高斯分布(0,1)中的点x,通过m+x×sigma的方式转化得到。
但在实际中,无法保证转换后的数据分布符合高斯分布,则可以通过测量输出代表的数据集与标准高斯分布数据集之间的差距,利用神经网络来将其训练成符合高斯分布的数据集。
#第一层
h1=tf.nn.relu(tf.add(tf.matmul(x, weights['w1']), weights['b1']))
#获得解码后的z_mean,z_log_sigma_sq两个并列层
z_mean = tf.add(tf.matmul(h1, weights['mean_w1']), weights['mean_b1'])
z_log_sigma_sq = tf.add(tf.matmul(h1, weights['log_sigma_w1']), weights['log_sigma_b1'])
#定义一个符合高斯分布的样本
eps = tf.random_normal(tf.stack([tf.shape(h1)[0], n_hidden_2]), 0, 1, dtype = tf.float32) #相当于x
#获得变换后的该数据分布中的z点
z =tf.add(z_mean, tf.multiply(tf.sqrt(tf.exp(z_log_sigma_sq)), eps))
#解码
h2=tf.nn.relu( tf.matmul(z, weights['w2'])+ weights['b2'])
reconstruction = tf.matmul(h2, weights['w3'])+ weights['b3']
h2out=tf.nn.relu( tf.matmul(zinput, weights['w2'])+ weights['b2'])
reconstructionout = tf.matmul(h2out, weights['w3'])+ weights['b3']
得到符合原始数据集上的具体点z后,就可以通过神经网络这个点z还原成原始数据reconstruction。
h2out和reconstructionout两个节点不属于训练中的结构,是未来生成指定数据时使用的。
4.构建模型的反向传播
定义损失函数的节点和优化算法op
网络的优化方向有两种:
- 一种是比较生成的数据分布与标准高斯分布的距离,这里使用KL离散度的公式(见latent_loss)
- 另一种是计算生成数据与原始数据间的损失,这里用的是平方差,也可以用交叉熵。
最后将两种损失值放在一起,通过Adam的随机梯度下降算法实现在训练中的优化参数。
#计算重建loss
#计算生成数据与原始数据的损失
reconstr_loss = 0.5 * tf.reduce_sum(tf.pow(tf.subtract(reconstruction, x), 2.0))
#计算生成的数据分布与标准高斯分布的损失
latent_loss = -0.5 * tf.reduce_sum(1 + z_log_sigma_sq - tf.square(z_mean) - tf.exp(z_log_sigma_sq), 1)
cost = tf.reduce_mean(reconstr_loss + latent_loss)
optimizer = tf.train.AdamOptimizer(learning_rate = 0.001).minimize(cost)
5.设置参数,进行训练
training_epochs = 50
batch_size = 128
display_step = 3
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for epoch in range(training_epochs):
avg_cost = 0.
total_batch = int(mnist.train.num_examples/batch_size)
# 遍历全部数据集
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)#取数据
# Fit training using batch data
_,c = sess.run([optimizer,cost], feed_dict={x: batch_xs})
#c = autoencoder.partial_fit(batch_xs)
# 显示训练中的详细信息
if epoch % display_step == 0:
print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(c))
print("完成!")
# 测试
print ("Result:", cost.eval({x: mnist.test.images}))
# 可视化结果
show_num = 10
pred = sess.run(
reconstruction, feed_dict={x: mnist.test.images[:show_num]})
f, a = plt.subplots(2, 10, figsize=(10, 2))
for i in range(show_num):
a[0][i].imshow(np.reshape(mnist.test.images[i], (28, 28)))
a[1][i].imshow(np.reshape(pred[i], (28, 28)))
plt.draw()
第一行代表原始的样本图片;第二行代表变分自编码重建后生成的图片,可以看到生成的数字不再一味单纯地学习形状,而是通过数据分布的方式学习规律,对原有图片具有更清晰的修正功能。
将均值和方差代表的二维数据在直角坐标系中展现如下。
6.高斯分布取样,生成模拟数据
为了进一步验证模型学到的数据分布的情况,在高斯分布抽样中取一些点,将其映射到模型中的z,然后通过解码部分还原啊成真实图片。
n = 15 # figure with 15x15 digits
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))
for i, yi in enumerate(grid_x):
for j, xi in enumerate(grid_y):
z_sample = np.array([[xi, yi]])
x_decoded = sess.run(reconstructionout,feed_dict={zinput:z_sample})
digit = x_decoded[0].reshape(digit_size, digit_size)
figure[i * digit_size: (i + 1) * digit_size,
j * digit_size: (j + 1) * digit_size] = digit
plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()
可以看出,在神经网络的世界里,从左下角到右上角显示了网络时按照图片的形状变换而排列的,并不像类一样,把数字按照1到9的顺序排列,因为机器学习的知识图片。