本文主要介绍如何使用TensorFlow实现Denoising Autoencoder(图像去噪自动编码器)。
下面是示例代码:
# 导入相关模块
import numpy as np
import sys
import tensorflow as tf
import matplotlib.pyplot as plt
'''
IPython提供了一组预定义的“魔术函数”(magic functions),可以用类似命令行
的语法调用。魔术函数分为两种:行魔法(line magics)和单元魔法(cell
magics)。行魔法以%字符作为前缀,其工作方式与操作系统的命令行调用非常
相似:它作用于所在的整行,可以返回结果,也可以用于赋值;单元魔法以%%
开头,必须出现在单元的第一行,并且作用于整个单元。
下面的%matplotlib inline是一个行魔法:启用之后,绘图命令的输出会像在
Jupyter笔记本中一样显示在前端,即直接显示在生成该输出的代码单元格下方,
生成的图像也会保存在笔记本文档中。不过该魔法好像只适用于Jupyter
Notebook和Jupyter QtConsole,在普通的Python脚本中无法使用。
'''
%matplotlib inline
# Load MNIST through the TensorFlow tutorial helper, using MNIST_data/ as the
# data directory.
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Graph inputs: `inputs_` receives the (noisy) images to reconstruct and
# `targets_` the clean reference images. Both are float32 batches shaped
# [batch, 28, 28, 1]; the batch dimension is left unspecified.
inputs_ = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1])
targets_ = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1])
def lrelu(x, alpha=0.1):
    """Leaky ReLU: return the element-wise maximum of ``alpha * x`` and ``x``.

    Args:
        x: Input tensor.
        alpha: Slope applied to negative inputs. The max-based formula only
            behaves as a leaky ReLU when ``alpha <= 1``.

    Returns:
        The result of ``tf.maximum(alpha * x, x)``.
    """
    return tf.maximum(alpha * x, x)
### Encoder
# Two conv + max-pool stages reduce the 28x28x1 input to a 7x7x32 code.
with tf.name_scope('en-convolutions'):
    conv1 = tf.layers.conv2d(inputs_, filters=32,
                             kernel_size=(3, 3),
                             strides=(1, 1),
                             padding='SAME',
                             use_bias=True,
                             activation=lrelu,)
# now 28x28x32 ('SAME' padding with stride 1 keeps the spatial size)
with tf.name_scope('en-pooling'):
    maxpool1 = tf.layers.max_pooling2d(conv1,
                                       pool_size=(2, 2),
                                       strides=(2, 2),)
# now 14x14x32
# NOTE: this scope name repeats the one used for conv1 above.
with tf.name_scope('en-convolutions'):
    conv2 = tf.layers.conv2d(maxpool1,
                             filters=32,
                             kernel_size=(3, 3),
                             strides=(1, 1),
                             padding='SAME',
                             use_bias=True,
                             activation=lrelu,)
# now 14x14x32
with tf.name_scope('encoding'):
    encoded = tf.layers.max_pooling2d(conv2,
                                      pool_size=(2, 2),
                                      strides=(2, 2),)
# now 7x7x32
### Decoder
# Mirrors the encoder: one conv layer, two stride-2 transposed convolutions
# that upsample 7x7 -> 14x14 -> 28x28, and a final 1-filter conv producing
# per-pixel logits.
# NOTE(review): the original indentation was lost; everything up to and
# including `decoded` is placed inside the 'decoder' name scope here. This
# only affects op names in the graph - confirm if ops are looked up by name.
with tf.name_scope('decoder'):
    conv3 = tf.layers.conv2d(encoded,
                             filters=32,
                             kernel_size=(3, 3),
                             strides=(1, 1),
                             padding='SAME',
                             use_bias=True,
                             activation=lrelu)
    # 7x7x32
    upsamples1 = tf.layers.conv2d_transpose(conv3,
                                            filters=32,
                                            kernel_size=3,
                                            padding='SAME',
                                            strides=2,
                                            name='upsample1')
    # now 14x14x32
    upsamples2 = tf.layers.conv2d_transpose(upsamples1,
                                            filters=32,
                                            kernel_size=3,
                                            padding='SAME',
                                            strides=2,
                                            name='upsamples2')
    # now 28x28x32
    logits = tf.layers.conv2d(upsamples2,
                              filters=1,
                              kernel_size=(3, 3),
                              strides=(1, 1),
                              name='logits',
                              padding='SAME',
                              use_bias=True)
    # now 28x28x1
    # Pass the logits through a sigmoid to obtain the reconstructed image.
    decoded = tf.sigmoid(logits, name='recon')
# Loss and optimizer.
# Per-pixel sigmoid cross-entropy between the raw logits and the clean
# target images.
loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=targets_,
                                               logits=logits)
# The learning rate is a placeholder, so it is supplied through feed_dict
# when the training op is run.
learning_rate = tf.placeholder(dtype=tf.float32)
# Scalar objective: the mean of the per-pixel losses.
cost = tf.reduce_mean(input_tensor=loss)
opt = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss=cost)
# Training
def _add_gaussian_noise(images, noise_factor):
    """Return `images` plus scaled standard-normal noise, clipped to [0, 1]."""
    noisy = images + noise_factor * np.random.normal(
        loc=0.0, scale=1.0, size=images.shape)
    return np.clip(noisy, 0., 1.)


sess = tf.Session()
saver = tf.train.Saver()
# Per-epoch loss histories. NOTE: this rebinds the name `loss`, which
# previously referred to the per-pixel loss tensor; the graph itself is
# unaffected because training only uses `cost` and `opt`.
loss = []
valid_loss = []
display_step = 1
epochs = 25
batch_size = 64
lr = 1e-5
noise_factor = 0.5
sess.run(tf.global_variables_initializer())
writer = tf.summary.FileWriter('./graphs', sess.graph)
# Number of full batches per epoch (loop-invariant, so computed once).
total_batch = mnist.train.num_examples // batch_size
for e in range(epochs):
    for ibatch in range(total_batch):
        batch_x = mnist.train.next_batch(batch_size)
        batch_test_x = mnist.test.next_batch(batch_size)
        # Validation images must come from the test batch; the original code
        # reshaped the training batch here, so the "validation" loss was
        # really measured on training data.
        imgs_test = batch_test_x[0].reshape((-1, 28, 28, 1))
        x_test_noisy = _add_gaussian_noise(imgs_test, noise_factor)
        imgs = batch_x[0].reshape((-1, 28, 28, 1))
        x_train_noisy = _add_gaussian_noise(imgs, noise_factor)
        # One optimization step: noisy images in, clean images as targets.
        batch_cost, _ = sess.run([cost, opt],
                                 feed_dict={inputs_: x_train_noisy,
                                            targets_: imgs,
                                            learning_rate: lr})
        # Evaluate only (no `opt`), so the weights are not updated here.
        batch_cost_test = sess.run(cost,
                                   feed_dict={inputs_: x_test_noisy,
                                              targets_: imgs_test})
    if (e + 1) % display_step == 0:
        print("Epoch: {}/{}...".format(e+1, epochs),
              "Training loss: {:.4f}".format(batch_cost),
              "Validation loss: {:.4f}".format(batch_cost_test))
    # Record the losses of the last batch of this epoch.
    loss.append(batch_cost)
    valid_loss.append(batch_cost_test)

# Plot the loss curves. The figure is created *before* plotting; the original
# called plt.figure() right before plt.show(), which opened an empty figure.
plt.figure()
epoch_axis = range(len(loss))
plt.plot(epoch_axis, loss, 'bo', label='Training loss')
plt.plot(epoch_axis, valid_loss, 'r', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs ',fontsize=16)
plt.ylabel('Loss',fontsize=16)
plt.legend()
plt.show()
saver.save(sess, 'encode_model')
# Denoise a small batch of test images and display original, noisy and
# reconstructed versions.
def _show_image_row(images, caption, title=None):
    """Print `caption` and draw every image of `images` in one figure.

    Args:
        images: Array indexed as [image, row, column, channel]; only channel
            0 is displayed.
        caption: Text printed before the figure is shown.
        title: Optional figure-level title.
    """
    plt.figure(figsize=(20, 4))
    if title is not None:
        # Figure-level title; plt.title() would implicitly create an extra
        # full-figure axes underneath the subplots.
        plt.suptitle(title)
    print(caption)
    for i, img in enumerate(images):
        plt.subplot(2, len(images), i + 1)
        plt.imshow(img[..., 0], cmap='gray')
    plt.show()


batch_x = mnist.test.next_batch(10)
imgs = batch_x[0].reshape((-1, 28, 28, 1))
noise_factor = 0.5
# Corrupt the images with Gaussian noise and clip back to the [0, 1] range.
x_test_noisy = imgs + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=imgs.shape)
x_test_noisy = np.clip(x_test_noisy, 0., 1.)
recon_img = sess.run(decoded, feed_dict={inputs_: x_test_noisy})

_show_image_row(imgs, "Original Images")
_show_image_row(x_test_noisy, "Noisy Images")
# The 'Reconstructed Images' title was originally attached to the figure of
# original images; it belongs to the reconstruction figure.
_show_image_row(recon_img, "Reconstruction of Noisy Images",
                title='Reconstructed Images')

# Release the summary writer and the session.
writer.close()
sess.close()
欢迎关注我的公众号:
编程技术与生活(ID:hw_cchang)