本篇博客主要介绍PyTorch中的自编码(AutoEncoder),并使用自编码来实现非监督学习。
示例代码:
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
from torch.autograd import Variable
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import numpy as np
# Hyperparameters (module-level constants).
# NOTE(review): the code that consumes most of these lies outside this view;
# meanings marked "presumably" are inferred from the names -- confirm them.
EPOCH: int = 10  # presumably the number of training epochs
BATCH_SIZE: int = 64  # presumably the mini-batch size
LR: float = 0.005  # presumably the optimizer learning rate
DOWNLOAD_MNIST: bool = False  # passed as `download=` when building the MNIST dataset
N_TEST_IMG: int = 5  # presumably how many test images get visualised
# Build the MNIST training dataset rooted at ./mnist/.
# The `download` flag follows DOWNLOAD_MNIST; per the torchvision docs the
# files are fetched only when it is true, otherwise they are expected to be
# present under the root directory already.
train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    download=DOWNLOAD_MNIST,
    transform=torchvision.transforms.ToTensor(),  # applied to every image sample
    train=True,  # select the training split
)
# 输出一个样本
# print(train_data.train_data.size())
# print(train_data.train_labels.size