PyTorch Notes | AutoEncoder
Reference: https://mofanpy.com/tutorials/machine-learning/torch/autoencoder/
1. AutoEncoder Overview
An autoencoder first compresses the original data (the Encoder) and then decodes it back (the Decoder). The goal is to make the output match the input as closely as possible, so whatever the compressed code retains should be the essential information of the original data.
It feels a little like a GAN, except that an autoencoder is trained to reconstruct its own input rather than to fool a discriminator.
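The idea above can be sketched in a few lines of PyTorch. The layer sizes here (784 → 64 → 3 and back) are arbitrary choices for illustration, not the tutorial's actual network:

```python
import torch
import torch.nn as nn

# Hypothetical minimal autoencoder: compress a 784-d input to a 3-d code, then reconstruct.
encoder = nn.Sequential(nn.Linear(784, 64), nn.Tanh(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.Tanh(), nn.Linear(64, 784), nn.Sigmoid())

x = torch.rand(16, 784)          # a fake batch of flattened 28x28 images
code = encoder(x)                # compressed representation, the "essence"
reconstruction = decoder(code)   # attempt to recover the original input

# Training minimizes the reconstruction error between output and input
loss = nn.MSELoss()(reconstruction, x)
print(code.shape, reconstruction.shape, loss.item())
```

Minimizing this loss forces the 3-d code to keep whatever information is most useful for rebuilding the input.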
2. PyTorch Implementation
Here we use the MNIST handwritten-digit dataset as input; the Decoder's job is to reconstruct the original image as faithfully as possible.
The middle (bottleneck) dimension is 3, so that the learned codes are easy to visualize.
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import numpy as np
BATCH_SIZE = 64
LR = 0.005 # learning rate
EPOCH = 10
#############
# Load data #
#############
DOWNLOAD_MNIST = False # set this to True if you haven't downloaded MNIST yet
train_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=True,  # this is training data
    transform=torchvision.transforms.ToTensor(),  # converts a PIL.Image or numpy.ndarray to a
    # torch.FloatTensor of shape (C x H x W) and normalizes pixel values to the range [0.0, 1.0]
    download=DOWNLOAD_MNIST,  # download it if you don't have it
)
print(train_data.data.size()) # (60000, 28, 28)
print(train_data.targets.size()) # (60000)
idx = 2  # pick one sample to visualize
plt.imshow(train_data.data[idx].numpy(), cmap='gray')
plt.title('%i' % train_data.targets[idx])
plt.show()
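With the data in hand, the model itself can be sketched as an encoder/decoder pair meeting at a 3-d bottleneck. The exact layer sizes below (784 → 128 → 64 → 12 → 3 and back) are an assumption about the network built later in the referenced tutorial:

```python
import torch
import torch.nn as nn

class AutoEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder: flattened 28x28 image down to a 3-d code (assumed layer sizes)
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128), nn.Tanh(),
            nn.Linear(128, 64), nn.Tanh(),
            nn.Linear(64, 12), nn.Tanh(),
            nn.Linear(12, 3),  # 3-d code so it can be plotted in 3D
        )
        # Decoder: mirror of the encoder, back up to 784 pixels
        self.decoder = nn.Sequential(
            nn.Linear(3, 12), nn.Tanh(),
            nn.Linear(12, 64), nn.Tanh(),
            nn.Linear(64, 128), nn.Tanh(),
            nn.Linear(128, 28 * 28), nn.Sigmoid(),  # pixel values back in [0, 1]
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

autoencoder = AutoEncoder()
x = torch.rand(64, 28 * 28)  # stand-in for one flattened MNIST batch
encoded, decoded = autoencoder(x)
print(encoded.shape, decoded.shape)
```

In training, each batch from a `Data.DataLoader` would be flattened to `(batch, 28*28)` and the MSE between `decoded` and the input minimized with `torch.optim.Adam(autoencoder.parameters(), lr=LR)`.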