之前的那篇文章中已经讲过了自编码器的基本结构,并且提到一个很好玩的用途——使用自编码器给图片去噪。其实基本过程还是很简单的:所谓图片去噪,不就是要训练出一个带噪图片(noise)到去噪图片(denoise)的一个映射吗?那么我们只需要在一堆干净的图片上搞点噪音,然后将这些噪音图片作为自编码器的输入,原来干净的图片作为标签值(目标值),来让自编码器做监督学习就行了。
基本结构有了,后面的玩法就多种多样了。不过涉及到图片数据,一般来说,大家都会使用卷积算子作为前向传播中信息提取的算子。由此,自编码器的encoder便很简单,就由一层层的卷积池化来完成,中间可以加入Relu和batchnormalization的操作。
等到decoder后,为了完成上采样的任务,我们有两种做法:第一种就是使用转置卷积来增加特征映射的维度,直到最终特征映射回到encoder输入图片的大小,那么decoder就完成了数据的重构;第二种便是先暴力上采样(比如使用双线性差值)来扩大特征映射的维度,等到扩充得和原本输入的图片差不多大时,我们再使用卷积进行像素点的调整。
下面使用第一种方法来作为decoder的主体。
本次使用的数据集是STL10,下载链接如下:
STL10cs.stanford.edu先导入需要的库:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from skimage.util import random_noise
from skimage.metrics import peak_signal_noise_ratio
import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.data as Data
import torch.optim as optim
from torchvision import transforms
from torchvision.datasets import STL10
import hiddenlayer as hl
%matplotlib inline
第一步:准备数据
首先,从上述链接中下载完STL10的数据集后,开始对数据做预处理。STL10中都是96*96的RGB图片,训练集放在了train_X.bin的文件中,可以用做自编码器的无监督学习,下面是第数据的预处理:
def read_image(data_path):
with open(data_path, "rb") as f:
data1 = np.fromfile(f, dtype=np.uint8)
# 塑形成[batch, c, h, w]
images = np.reshape(data1, [-1, 3, 96, 96])
# 图像转化为RGB(即最后一个维度是通道维度)的形式,方便使用matplotlib进行可视化
images = np.transpose(images, [0, 3, 2, 1])
return images / 255
data_path = "../数据集/STL10/stl10_binary/train_X.bin"
images = read_image(data_path)
images.shape
out:
(5000, 96, 96, 3)
下面定义一个函数,为干净的图片添加高斯噪音,这部分添加了噪音的数据,将成为自编码器的输入。其中的random_noise是属于skimage.util下的一个方法。
def gaussian_noise(images, sigma):
"""sigma: 噪声标准差"""
sigma2 = sigma**2 / (255 ** 2) # 噪声方差
images_noisy = np.zeros_like(images)
for ii in range(images.shape[0]):
image = images[ii]
# 使用skimage中的函数增加噪音
noise_im = random_noise(image, mode="gaussian", var=sigma2, clip=True)
images_noisy[ii] = noise_im
return images_noisy
images_noise = gaussian_noise(images, 30)
print("image_noise:", images_noise.min(), "~", images_noise.max())
out:
image_noise: 0.0 ~ 1.0
下面可视化一些添加了噪音后的图片,其中iamges[ii, ...]中的...是numpy库中的语法糖,作用等价于若干个:,:,:的组合,也就是数组剩余的每个维度都全取。
plt.figure(figsize=[6, 6])
for ii in np.arange(36):
plt.subplot(6, 6, ii + 1)
plt.imshow(images[ii, ...])
plt.axis("off")
plt.show()
# 带噪音的图片
plt.figure(figsize=[6, 6])
for ii in np.arange(36):
plt.subplot(6, 6, ii + 1)
plt.imshow(images_noise[ii, ...])
plt.show()
out: