本代码参照莫烦Python课程编写，基于PyTorch 1.8.1实现。
莫烦Python课程链接：（原文链接缺失）
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
from mpl_toolkits.mplot3d import Axes3D
# Seed torch's global RNG so that subsequent torch random operations
# (e.g. the DataLoader shuffling below) are reproducible from run to run.
torch.manual_seed(0)

# Hyperparameters.
# NOTE(review): Epoch, Lr and N_TEST_IMAGE are not used in this part of the
# file; presumably the training loop / visualisation further down reads them.
Epoch = 6
batchsize = 64
Lr = 0.01
N_TEST_IMAGE = 5

# Download MNIST only when it is not already on disk. A hard-coded
# ``download_mnist = False`` makes torchvision raise a RuntimeError
# ("Dataset not found") on a fresh checkout where ./mnist does not exist yet.
# ``os.listdir`` is only reached when the directory exists (short-circuit ``or``).
MNIST_ROOT = './mnist'
download_mnist = not os.path.isdir(MNIST_ROOT) or not os.listdir(MNIST_ROOT)

train_data = torchvision.datasets.MNIST(
    root=MNIST_ROOT,
    train=True,
    # ToTensor converts each PIL image to a float tensor scaled to [0, 1].
    transform=torchvision.transforms.ToTensor(),
    download=download_mnist,
)
train_loader = Data.DataLoader(dataset=train_data, batch_size=batchsize, shuffle=True)
class AutoEncoder(nn.Module):
def __init__(self):
super(AutoEncoder, self).__init__()
self.encoder=nn.Sequential(
nn.Linear(28*28, 128),
nn.Tanh(),
nn.Linear(128,64),
nn.Tanh(),
nn.Linear(64, 12),
nn