自动编码器是一种特定得深度学习模型, 那么我们想要改进有哪些方法呢?
首先我们可以在深度上面考虑: 神经网络为什么要“深”?
因为深度达到一定程度之后, 可以提升泛化能力, 所以神经网络要深.
其次我们可以在稀疏性上考虑: 压缩特征一定要少于输入特征吗?
自动编码器隐层的神经元数量少于输入层数量,称为欠完备. 而自动编码器隐层的神经元数量大于输入层数量,称为过完备. 欠完备神经元少, 这样就比较稠密, 而过完备神经元多, 这样就比较稀疏. 所以推荐使用逐层减少的架构(欠完备),然后加入稀疏惩罚.
最后我们可以在抗噪上进行考虑:如何让模型能够对抗噪声,提升泛化性?
深度学习模型并不可靠,数据加入微小噪声,结果会天差地别. 所以我们需要让模型负重前行,主动让数据充满噪声.
下面我们举个栗子:
在这里我们使用的损失函数是:loss = , 还有其他图像数据集增强方法:随机旋转,翻转图像均可有效提升模型泛化能力.
下面我们梳理一下改进自编码器的代码书写流程:
1. 数据读取--> 数据集扩充(去噪)
2. 模型搭建-->卷积编码器和反卷积编码器
3. 模型训练-->添加L1惩罚项(稀疏)
4. 模型比较
下面我们按照步骤来进行代码的编写
1. 数据读取--> 数据集扩充(去噪)
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torch.utils.data as Data
import torchvision.transforms as T
import torch.nn as nn
# 数据预处理:随机剪裁,随机旋转,翻转,标准化处理
tranform = T.Compose([
T.RandomCrop(32, padding=4), # 随机剪裁
T.RandomRotation(10), # 随机旋转
T.RandomHorizontalFlip(), # 随机翻转
T.ToTensor(), # 1.将数据转换为tensor,2.对数据进行归一化:0-1
T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # 标准化为----》 -1——1
])
# 读取cifar10数据集
train_data = torchvision.datasets.CIFAR10(
root="./data/cifar_data",
train=True,
transform=tranform,
download=True,
)
print(train_data.data.shape) # [50000,32,32,3]
2. 模型搭建-->卷积编码器和反卷积编码器
在这里先介绍一下卷积计算公式: out= [(input - F + 2P)/S]+ 1 说明: input:输入大小, F:卷积核大小 P:零填充大小 S:跨步大小.
反卷积的计算公式: out= s*(input−1)+k−2p。 比如k=3,p=1,s=2, input=10,输出为:2*(10-1)+3-2*1=19
class AutoEncoder(torch.nn.Module):
def __init__(self):
# 构建卷积编码器
super(AutoEncoder, self).__init__()
ndf = 128 # 卷积核数量
self.encoder = nn.Sequential(
# out =(input - F + 2p)/ s +1 out=(32-4+2*1)/2 +1 =16
# input =[none,3,32,32]
nn.Conv2d(in_channels=3, out_channels=ndf, kernel_size=4, stride=2, padding=1), # out =[none,128,16,16]
nn.BatchNorm2d(ndf),
nn.ReLU(inplace=True),
# in =[none,ndf,16,16] ----> out=[none,2*ndf,8,8]
nn.Conv2d(in_channels=ndf, out_channels=2 * ndf, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(2 * ndf),
nn.LeakyReLU(0.2, inplace=True),
# in = [none,2*ndf,8,8] ----->out=[none,4*ndf,4,4]
nn.Conv2d(in_channels=2 * ndf, out_channels=4 * ndf, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(4 * ndf),
nn.LeakyReLU(0.2, inplace=True),
)
self.decoder = nn.Sequential(
# out = s *(input -1) +k -2p out=2*(4-1)+ 4 -2*1=8
# in =[ none,4*ndf,4,4] ----> out=[none,2*ndf,8,8]
nn.ConvTranspose2d(in_channels=4 * ndf, out_channels=2 * ndf, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(2 * ndf),
nn.LeakyReLU(0.2, inplace=True),
# in =[none,2*ndf,8,8] ----->out=[none,ndf,16,16]
nn.ConvTranspose2d(in_channels=2 * ndf, out_channels=ndf, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ndf),
nn.LeakyReLU(0.2, inplace=True),
# in =[none,ndf,16,16] ----->out[none,3,32,32]
nn.ConvTranspose2d(in_channels=ndf, out_channels=3, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(3),
nn.Tanh(), # -1---1
)
def forward(self, x):
encode = self.encoder(x)
decode = self.decoder(encode)
return encode, decode
autoencoder = AutoEncoder().cuda()
print(autoencoder)
3. 模型训练-->添加L1惩罚项(稀疏)
# 超参数设置
EPOCH = 10
BATCH_SIZE = 64
N_TEST_IMG = 5
LR = 0.005
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR,weight_decay=0.01)
lamda = 0.01
loss_fuc = nn.MSELoss() # 最小均方误差损失
f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 5))
plt.ion() # 自动关闭图像
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
# 训练模型
for epoch in range(EPOCH):
for i, (bx, by) in enumerate(train_loader):
bx = bx.cuda()
encode, decode = autoencoder(bx)
# 计算L1损失
# reg_loss = 0
# for parm in autoencoder.parameters():
# reg_loss += torch.sum(abs(parm))
# loss = loss_fuc(decode, bx) + lamda * reg_loss
loss = loss_fuc(decode,bx)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if i % 500 == 0:
print("Epoch:,", epoch, " train loss:", loss.item())
_,decode =autoencoder(bx) # decode=[bx,3,32,32]
decode = decode.data.cpu().numpy().transpose([0,2,3,1]) # -1---1
decode = (decode+1)/2
bx = bx.data.cpu().numpy().transpose([0,2,3,1])
bx =(bx+1)/2
for j in range(N_TEST_IMG):
a[0][j].clear()
a[0][j].imshow(bx[j])
a[0][j].set_xticks(())
a[0][j].set_yticks(())
a[1][j].clear()
a[1][j].imshow(decode[j])
a[1][j].set_xticks(())
a[1][j].set_yticks(())
plt.draw()
plt.pause(0.05)
plt.ioff()
plt.show()
4. 模型比较
无添加惩罚项
添加L1稀疏惩罚项
问题思考: 为什么无添加的惩罚项看似更好于其他项?