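# 本文用深度学习的方法，在 TensorFlow 2.0 框架下分别实现了基于反卷积（转置卷积）和基于上采样的去噪自编码器，并对比了两者的效果。在 MNIST 数据集上，上采样去噪的实际效果优于反卷积去噪。（注意：反卷积模型应以带 sigmoid 激活的 x10 作为输出；若直接以无激活的 x9 作为输出，会与 binary_crossentropy 损失不匹配，对比结论可能因此偏向上采样方法。）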
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
# Let TensorFlow 2 allocate GPU memory on demand instead of reserving the whole
# device up front. This is the native TF2 replacement for creating a TF1-style
# InteractiveSession with gpu_options.allow_growth = True; it has to run before
# the GPUs are initialized (i.e. before any model or tensor is placed on them).
for gpu in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)
# In[2]:
# load_data() returns ((train_images, train_labels), (test_images, test_labels)).
# Naming convention used throughout this script: x_* is the training split and
# y_* is the *test* split (y_train = test images, y_label = test labels).
mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
x_train, x_label = mnist_train
y_train, y_label = mnist_test
(x_train.shape, y_train.shape)
# In[3]:
# Convert to float32, scale pixels by 1/255 (raw MNIST pixels are 0-255, so the
# result lies in [0, 1]) and append a single channel axis so every image has
# the (28, 28, 1) layout the Conv2D input layers expect.
x_train = (x_train.astype('float32') / 255).reshape(len(x_train), 28, 28, 1)
y_train = (y_train.astype('float32') / 255).reshape(len(y_train), 28, 28, 1)
y_train.shape
# In[4]:
noise_factor = 0.5
# Corrupt each image with additive Gaussian noise (std = noise_factor), then:
#  * clip back into [0, 1] so the noisy inputs stay valid pixel intensities,
#    in the same range as the clean target images;
#  * cast to float32 -- NumPy's float64 noise would otherwise promote the whole
#    array to float64, doubling memory for data the model consumes as float32.
noise_x = np.clip(
    x_train + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_train.shape),
    0.0,
    1.0,
).astype('float32')
noise_y = np.clip(
    y_train + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=y_train.shape),
    0.0,
    1.0,
).astype('float32')
# In[5]:
# Show training images 1..n/2: clean images in the top row, their noisy
# versions directly below in the bottom row.
n = 20
cols = n // 2  # integer column count; matplotlib rejects float grid sizes like n/2
plt.figure(figsize=(60, 12))  # figure size in inches: 60 wide, 12 tall
for i in range(1, cols + 1):
    # Top row: clean training image.
    ax = plt.subplot(2, cols, i)
    plt.imshow(x_train[i].reshape(28, 28), cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    # Bottom row: the same image with Gaussian noise added.
    ax = plt.subplot(2, cols, cols + i)
    plt.imshow(noise_x[i].reshape(28, 28), cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()
# ![手写数据集原图像（上排）与加入噪声后的图像（下排）对比](https://img-blog.csdnimg.cn/20200311013421216.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2hlbGxkb2dlcg==,size_16,color_FFFFFF,t_70)
# In[6]:
# Encoder of the transposed-convolution autoencoder: three conv + 2x2 max-pool
# stages. 'same' pooling rounds odd sizes up, so 7x7 pools to 4x4.
# NOTE: the name `input` shadows the Python builtin; kept because the Model
# constructor below refers to it.
input = tf.keras.layers.Input(shape=(28, 28, 1))

# 28x28x1 -> 28x28x32
x1 = tf.keras.layers.Conv2D(
    filters=32, kernel_size=(3, 3), padding='same', activation='relu', name='x1'
)(input)
# 28x28x32 -> 14x14x32
x2 = tf.keras.layers.MaxPooling2D(
    pool_size=(2, 2), strides=(2, 2), padding='same', name='x2'
)(x1)
# 14x14x32 -> 14x14x64
x3 = tf.keras.layers.Conv2D(
    filters=64, kernel_size=(3, 3), padding='same', activation='relu', name='x3'
)(x2)
# 14x14x64 -> 7x7x64
x4 = tf.keras.layers.MaxPooling2D(
    pool_size=(2, 2), strides=(2, 2), padding='same', name='x4'
)(x3)
# 7x7x64 -> 7x7x64
x5 = tf.keras.layers.Conv2D(
    filters=64, kernel_size=(3, 3), padding='same', activation='relu', name='x5'
)(x4)
# 7x7x64 -> 4x4x64 (bottleneck)
x6 = tf.keras.layers.MaxPooling2D(
    pool_size=(2, 2), strides=(2, 2), padding='same', name='x6'
)(x5)
# In[7]:
# Inspect the encoder output shapes (batch dimension is None):
# x1 (28, 28, 32), x2 (14, 14, 32), x3 (14, 14, 64),
# x4 (7, 7, 64), x5 (7, 7, 64), x6 (4, 4, 64).
# Read the .shape attribute -- the uncalled get_shape method would only display
# bound-method objects rather than the shapes themselves.
x1.shape, x2.shape, x3.shape, x4.shape, x5.shape, x6.shape
# In[8]:
# Decoder of the transposed-convolution autoencoder, starting from x6 (4x4x64).
# 4x4x64 -> 7x7x32. padding='valid' adds NO zero padding; with stride 1 a
# transposed convolution grows each spatial dim by kernel_size - 1
# (4 + 4 - 1 = 7), which restores the 7x7 size seen after the second pooling.
x7 = tf.keras.layers.Conv2DTranspose(32, (4, 4), strides=(1, 1), padding='valid', name='x7')(x6)
# 7x7x32 -> 14x14x16 (stride 2 with 'same' padding doubles height and width).
x8 = tf.keras.layers.Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same', name='x8')(x7)
# 14x14x16 -> 28x28x1. No activation is given, so x9 is linear / unbounded.
x9 = tf.keras.layers.Conv2DTranspose(1, (2, 2), strides=(2, 2), padding='same', name='x9')(x8)
# 28x28x1 -> 28x28x1. The sigmoid maps pixels into (0, 1), which is the range
# the binary_crossentropy training loss expects of the model's output.
x10 = tf.keras.layers.Conv2D(1, (3, 3), padding='same', activation='sigmoid', name='x10')(x9)
# In[9]:
# Inspect the decoder output shapes (batch dimension is None):
# x7 (7, 7, 32), x8 (14, 14, 16), x9 (28, 28, 1), x10 (28, 28, 1).
# Use .shape uniformly instead of mixing it with the uncalled get_shape method.
x7.shape, x8.shape, x9.shape, x10.shape
# In[10]:
# Transposed-convolution denoising autoencoder. The output must be the sigmoid
# layer x10: the model is trained with binary_crossentropy (from_logits=False),
# which expects predictions in (0, 1). The linear x9 is unbounded, and the loss
# clips such predictions, which stalls the gradients. This mirrors the upsampling
# model, whose output xx13 is also a sigmoid conv.
XXX = tf.keras.Model(inputs=input, outputs=x10)
# In[11]:
# Adam with an explicit learning rate of 1e-3 for the transposed-conv model.
optimizer_1 = tf.keras.optimizers.Adam(learning_rate=1e-3)
# In[12]:
# Per-pixel binary cross-entropy against the clean image. For this loss Keras
# resolves 'accuracy' to a thresholded per-pixel binary accuracy, so treat it
# as a rough signal only.
XXX.compile(
    loss='binary_crossentropy',
    optimizer=optimizer_1,
    metrics=['accuracy'],
)
# In[ ]:
# Train the transposed-conv model to map noisy training images back to the clean
# ones; the noisy/clean test pair serves as validation data.
# Keep the returned History (as history_2 does for the upsampling model) so the
# two models' loss curves can be compared afterwards.
history_1 = XXX.fit(
    noise_x, x_train,
    batch_size=64,
    epochs=50,
    shuffle=True,
    validation_data=(noise_y, y_train),
)
# In[24]:
# Denoise the noisy test images with the transposed-convolution model.
decoded_imgs = XXX.predict(x=noise_y)
# In[25]:
# Compare the transposed-conv model's denoised outputs (top row) with the noisy
# test images that were fed to it (bottom row), for test samples 1..n.
n = 10
plt.figure(figsize=(20, 4))
for i in range(1, n + 1):  # n + 1 so all n columns of each row are filled
    # Top row: denoised reconstruction.
    ax = plt.subplot(2, n, i)
    plt.imshow(decoded_imgs[i].reshape(28, 28), cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    # Bottom row: noisy model input.
    ax = plt.subplot(2, n, i + n)
    plt.imshow(noise_y[i].reshape(28, 28), cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()
# ![反卷积去噪结果（上排）与加噪输入图像（下排）的对比](https://img-blog.csdnimg.cn/20200311013523418.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2hlbGxkb2dlcg==,size_16,color_FFFFFF,t_70)
# In[ ]:
# Render the transposed-conv model's layer graph with per-layer shapes.
# NOTE(review): plot_model needs the pydot and graphviz packages installed.
QQ = tf.keras.utils.plot_model(model=XXX, show_shapes=True)
# ## 上采样法
# In[62]:
# Encoder of the upsampling autoencoder: three conv + 2x2 max-pool stages.
# MaxPooling2D strides default to pool_size; 'same' padding rounds 7x7 up to 4x4.
input_2 = tf.keras.Input(shape=(28, 28, 1))

# 28x28x1 -> 28x28x16
xx1 = tf.keras.layers.Conv2D(
    filters=16, kernel_size=(3, 3), padding='same', activation='relu', name='xx1'
)(input_2)
# 28x28x16 -> 14x14x16
xx2 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same', name='xx2')(xx1)
# 14x14x16 -> 14x14x8
xx3 = tf.keras.layers.Conv2D(
    filters=8, kernel_size=(3, 3), padding='same', activation='relu', name='xx3'
)(xx2)
# 14x14x8 -> 7x7x8
xx4 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same', name='xx4')(xx3)
# 7x7x8 -> 7x7x8
xx5 = tf.keras.layers.Conv2D(
    filters=8, kernel_size=(3, 3), padding='same', activation='relu', name='xx5'
)(xx4)
# 7x7x8 -> 4x4x8 (bottleneck)
xx6 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same', name='xx6')(xx5)
# In[63]:
# Inspect the encoder output shapes (batch dimension is None):
# xx1 (28, 28, 16), xx2 (14, 14, 16), xx3 (14, 14, 8),
# xx4 (7, 7, 8), xx5 (7, 7, 8), xx6 (4, 4, 8).
# Read .shape (the uncalled get_shape method would show bound-method objects).
xx1.shape, xx2.shape, xx3.shape, xx4.shape, xx5.shape, xx6.shape
# In[64]:
# Decoder of the upsampling autoencoder, starting from xx6 (4x4x8).
# 4x4x8 -> 4x4x8
xx7 = tf.keras.layers.Conv2D(8, (3, 3), activation="relu", padding="same", name='xx7')(xx6)
# 4x4x8 -> 8x8x8 (UpSampling2D doubles height and width)
xx8 = tf.keras.layers.UpSampling2D((2, 2), name='xx8')(xx7)
# 8x8x8 -> 8x8x8
xx9 = tf.keras.layers.Conv2D(8, (3, 3), activation="relu", padding="same", name='xx9')(xx8)
# 8x8x8 -> 16x16x8
xx10 = tf.keras.layers.UpSampling2D((2, 2), name='xx10')(xx9)
# 16x16x8 -> 14x14x16. padding='valid' with a 3x3 kernel trims one pixel per side
# (16 - 3 + 1 = 14), so the next 2x upsampling lands exactly on 28x28, not 32x32.
xx11 = tf.keras.layers.Conv2D(16, (3, 3), padding='valid', activation='relu', name='xx11')(xx10)
# 14x14x16 -> 28x28x16
xx12 = tf.keras.layers.UpSampling2D((2, 2), name='xx12')(xx11)
# 28x28x16 -> 28x28x1; sigmoid keeps the reconstructed pixels in (0, 1).
xx13 = tf.keras.layers.Conv2D(1, (3, 3), padding='same', activation='sigmoid', name='xx13')(xx12)
# In[ ]:
# In[65]:
# Inspect the decoder output shapes (batch dimension is None):
# xx7 (4, 4, 8), xx8 (8, 8, 8), xx9 (8, 8, 8), xx10 (16, 16, 8),
# xx11 (14, 14, 16), xx12 (28, 28, 16), xx13 (28, 28, 1).
xx7.shape, xx8.shape, xx9.shape, xx10.shape, xx11.shape, xx12.shape, xx13.shape
# In[67]:
# Wrap the upsampling encoder/decoder graph (input_2 -> xx13) in a Model.
XDX = tf.keras.Model(input_2, xx13)
# In[68]:
# Print the per-layer output shapes and parameter counts.
XDX.summary()
# In[69]:
# Separate Adam instance (learning rate 1e-3) for the upsampling model.
optimizer_2 = tf.keras.optimizers.Adam(learning_rate=1e-3)
# In[70]:
# Same per-pixel binary cross-entropy objective as the transposed-conv model.
XDX.compile(
    loss='binary_crossentropy',
    metrics=['acc'],
    optimizer=optimizer_2,
)
# In[71]:
# Train the upsampling model on (noisy -> clean) training pairs, validating on
# the noisy/clean test pair; the History object is kept in history_2.
history_2 = XDX.fit(
    x=noise_x,
    y=x_train,
    batch_size=64,
    epochs=50,
    shuffle=True,
    validation_data=(noise_y, y_train),
)
# In[72]:
# Denoise the noisy test images with the upsampling model.
decoded_imgs_1 = XDX.predict(x=noise_y)
# In[73]:
# Compare the upsampling model's denoised outputs (top row) with the noisy test
# images that were fed to it (bottom row), for test samples 1..n.
n = 10
plt.figure(figsize=(20, 4))
for i in range(1, n + 1):  # n + 1 so all n columns of each row are filled
    # Top row: denoised reconstruction.
    ax = plt.subplot(2, n, i)
    plt.imshow(decoded_imgs_1[i].reshape(28, 28), cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    # Bottom row: noisy model input.
    ax = plt.subplot(2, n, i + n)
    plt.imshow(noise_y[i].reshape(28, 28), cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()