# AE.py
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import torch
from torch import nn
class AE(nn.Module):
    """Fully-connected autoencoder for flattened images.

    Encodes ``[b, prod(img_shape)]`` down to ``[b, latent_dim]`` and decodes
    back, squashing outputs into ``[0, 1]`` with a final sigmoid (matching
    ``ToTensor``-normalized pixel inputs).

    Args:
        img_shape: per-sample input shape, default ``(1, 28, 28)`` (MNIST).
        latent_dim: size of the bottleneck, default 20.
    """

    def __init__(self, img_shape=(1, 28, 28), latent_dim=20):
        super().__init__()
        self.img_shape = tuple(img_shape)
        # flattened input size, e.g. 1*28*28 = 784 for the defaults
        in_dim = 1
        for d in self.img_shape:
            in_dim *= d
        self.in_dim = in_dim

        # [b, in_dim] => [b, latent_dim]
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, latent_dim),
            nn.ReLU(),
        )
        # [b, latent_dim] => [b, in_dim]
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, in_dim),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Reconstruct ``x``.

        Args:
            x: tensor of shape ``[b, *img_shape]``.

        Returns:
            ``(x_hat, None)`` — reconstruction with the same shape as ``x``;
            the ``None`` keeps the interface parallel to a VAE that would
            return a KL-divergence term in that slot.
        """
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, self.in_dim)
        # encoder
        x = self.encoder(x)
        # decoder
        x = self.decoder(x)
        # reshape back to image layout
        x = x.view(batchsz, *self.img_shape)
        return x, None
# AEmain.py
import torch
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision import transforms, datasets
from ae import AE
import visdom
def main():
    """Train the autoencoder on MNIST and show reconstructions in visdom.

    Expects the MNIST files to already exist under ``../data``
    (``download=False``) and a visdom server to be running.
    """
    tf = transforms.Compose([transforms.ToTensor()])
    mnist_train = datasets.MNIST('../data', True, transform=tf, download=False)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)
    mnist_test = datasets.MNIST('../data', False, transform=tf, download=False)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=False)

    # BUG FIX: iterator.next() is Python 2 only; Python 3 uses next(iterator).
    x, _ = next(iter(mnist_train))
    print("x:", x.shape)

    # Fall back to CPU so the script still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = AE().to(device)
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    print(model)

    viz = visdom.Visdom()

    for epoch in range(1000):
        for batchidx, (x, _) in enumerate(mnist_train):
            # x: [b, 1, 28, 28]
            x = x.to(device)
            x_hat, _ = model(x)
            loss = criteon(x_hat, x)
            # backprop
            optimizer.zero_grad()  # reset gradients to zero
            loss.backward()        # backpropagate to compute gradients
            optimizer.step()       # update parameters w, b
        print(epoch, 'loss:', loss.item())

        # Visualize one test batch per epoch (no gradients needed).
        x, _ = next(iter(mnist_test))
        x = x.to(device)
        with torch.no_grad():
            x_hat, _ = model(x)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))
# Script entry point: run training only when executed directly, not on import.
if __name__=='__main__':
    main()