PyTorch-Tutorial项目中的自编码器实现详解

PyTorch-Tutorial项目中的自编码器实现详解

PyTorch-Tutorial Build your neural network easy and fast, 莫烦Python中文教学 PyTorch-Tutorial 项目地址: https://gitcode.com/gh_mirrors/pyt/PyTorch-Tutorial

自编码器概述

自编码器(AutoEncoder)是一种无监督学习的神经网络模型,主要用于数据的降维和特征提取。它通过将输入数据压缩到一个低维空间(编码),然后再从这个低维表示重建原始数据(解码),从而学习数据的有用特征。

代码实现解析

1. 数据准备

首先加载MNIST手写数字数据集,这是深度学习中最常用的基准数据集之一:

train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST
)

数据加载器将数据集分成小批量,便于训练:

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

2. 自编码器模型结构

自编码器由编码器和解码器两部分组成:

class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()
        
        # 编码器部分
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 12),
            nn.Tanh(),
            nn.Linear(12, 3),  # 压缩到3维特征
        )
        
        # 解码器部分
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.Tanh(),
            nn.Linear(12, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28*28),
            nn.Sigmoid()  # 输出在0-1范围内
        )

编码器将784维(28×28)的输入图像逐步压缩到3维空间,解码器则从3维空间重建原始图像。使用Tanh作为激活函数,最后解码器输出使用Sigmoid确保像素值在0-1之间。

3. 训练过程

训练过程使用均方误差(MSE)作为损失函数,Adam优化器进行参数更新:

optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()

for epoch in range(EPOCH):
    for step, (x, b_label) in enumerate(train_loader):
        b_x = x.view(-1, 28*28)  # 展平输入
        b_y = x.view(-1, 28*28)  # 目标输出
        
        encoded, decoded = autoencoder(b_x)
        loss = loss_func(decoded, b_y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

4. 可视化结果

训练过程中实时显示原始图像和重建图像的对比:

# 初始化图像显示
f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
plt.ion()

# 训练过程中更新显示
_, decoded_data = autoencoder(view_data)
for i in range(N_TEST_IMG):
    a[1][i].imshow(np.reshape(decoded_data.data.numpy()[i], (28, 28)), cmap='gray')

训练完成后,将200个样本的3维编码结果在3D空间中可视化:

fig = plt.figure(2); ax = Axes3D(fig)
X, Y, Z = encoded_data.data[:, 0].numpy(), encoded_data.data[:, 1].numpy(), encoded_data.data[:, 2].numpy()
values = train_data.train_labels[:200].numpy()
for x, y, z, s in zip(X, Y, Z, values):
    c = cm.rainbow(int(255*s/9)); ax.text(x, y, z, s, backgroundcolor=c)

技术要点解析

  1. 数据预处理:MNIST图像被归一化到[0,1]范围,并展平为784维向量
  2. 网络设计:采用全连接层逐步压缩/扩展维度,使用Tanh激活函数防止梯度消失
  3. 损失函数:使用MSE衡量重建图像与原始图像的差异
  4. 可视化:3D编码空间展示不同数字类别的分布情况

实际应用场景

自编码器在实际中有多种应用:

  1. 数据降维:比PCA等线性方法更强大的非线性降维
  2. 异常检测:重建误差大的样本可能是异常值
  3. 图像去噪:训练时加入噪声,学习重建干净图像
  4. 特征提取:编码器部分可作为特征提取器

总结

本教程实现了一个基本的自编码器模型,展示了如何使用PyTorch构建、训练和评估自编码器。通过将784维的手写数字图像压缩到3维空间并重建,我们不仅理解了自编码器的工作原理,还通过可视化直观地看到了不同数字在潜在空间中的分布情况。

PyTorch-Tutorial Build your neural network easy and fast, 莫烦Python中文教学 PyTorch-Tutorial 项目地址: https://gitcode.com/gh_mirrors/pyt/PyTorch-Tutorial

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

江燕娇

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值