AI入门:神经网络实战----AutoEncoder

AutoEncoder是一种神经网络结构,用于去除输入数据中的噪音。通过压缩和解压缩过程,它能重建原始数据。本文介绍了简单的三层全连接AutoEncoder,使用MNIST数据集展示其在图像去噪上的效果,并讨论了改进方法。
摘要由CSDN通过智能技术生成

前言
我们使用的仪器、设备都是有误差的,所以我们拿到的原始数据也都是有误差 (噪音) 的。误差有正有负,但总的结果一般都是在真实结果附近。那么我们怎么才能去掉噪音得到真实结果呢?这就要我们的AutoEncoder出马了。

在这里插入图片描述

什么是AutoEncoder,为什么要用AutoEncoder
AutoEncoder是一个蝴蝶状的神经网络结构,通过中间的瓶颈把输入数据中的噪音压缩除去,最后再重新放大生成与原始数据相同尺寸的图片,这就是AutoEncoder,中文翻译为自动编码器。从这段文字的描述来看,AutoEncoder是一个生成式神经网络结构,主要用来去除原始数据中的一些噪音,然后重新生成原始数据。
AutoEncoder中可以只使用全连接层,也可以使用卷积层或者以后要学的其它结构。这里为了简单起见,我们只使用全连接层,且一共只有三层。

在这里插入图片描述

AutoEncoder相关函数
这个AutoEncoder神经网络结构比较简单,三层之间都是用全连接方式连接。为了防止过拟合,我们使用dropout(),激活函数是sigmoid()。
值得注意的是损失函数。我们希望生成去除噪音之后的原始图像,因此我们的生成数据应该是跟原始图片进行对比,计算损失。另外,这里使用均方差计算损失。

编写AutoEncoder
下面就是AutoEncoder的网络模型:

在这里插入图片描述

在定义完模型之后,我们再实例化模型,指定模型的优化方式为Adam,使用均方差作为损失函数。

在这里插入图片描述

AutoEncoder的训练和测试
这里我们使用MNIST数据集,数据集中的原始图片是不带噪音的。因此需要我们制造有噪音的数据:在原始数据上增加一个强度 * 0.3的标准正态分布的随机数据。这个随机数据的尺寸要跟原始图片的维度相一致。

data = data + torch.randn([data.size(0), 28 * 28]) * 0.3

有了这个带噪音的数据之后,AutoEncoder就可以开始训练了。训练的代码与全连接神经网络的训练过程完全一样。这里就不重复了。
区别是在训练完一轮数据之后,需要把前10个测试数据显示出来。显示代码如下:图像分为三行,第0行表示没有加噪音的原始数据;第1行表示加了噪音之后的原始数据;第2行表示AutoEncoder去噪之后的数据。

在这里插入图片描述

在这里插入图片描述
从上面的显示结果看,这个AutoEncoder的效果是一般,数字6、8、0都比较模糊。我们可以通过增加网络的深度、调节瓶颈的大小等方式进行改进。
总结
这一节里,我们提出了使用AutoEncoder去除图像中的噪音,采用的是三层全连接操层,中间两层都采用dropout防止过拟合。本节使用的AutoEncoder达到了去除噪音的效果,但是生成的新图片的效果一般。在下一节中,我们将使用变分自编码 (VAE)。VAE中将提出一种新的结构,敬请期待。

import torch
import torch.optim as optim
from torch.nn import functional as F
import torch.nn as nn
import torchvision
from torch.nn import Module
import matplotlib.pyplot as plt

#三个超参数
learning_rate = 0.01 #学习率
epochs = 10           #总的训练次数
batch_size = 128      #每个批次的样本数量
show_num = 10  # 显示的图片个数   
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #GPU或CPU

train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('datasets/mnist_data',
                train=True,
                download=True,
                transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),                       # 数据类型转化
                torchvision.transforms.Normalize((0.1307, ), (0.3081, )<
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值