#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Created on Mon Jan 1 12:45:57 2018
@author: pc"""
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import numpy as np

# torch.manual_seed(1)    # reproducible
# Hyper Parameters
EPOCH = 10
BATCH_SIZE = 100
LR = 0.005              # learning rate
DOWNLOAD_MNIST = True
N_TEST_IMG = 5
# MNIST digits dataset
train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,                                     # this is training data
    transform=torchvision.transforms.ToTensor(),    # converts a PIL.Image or numpy.ndarray to a
                                                    # torch.FloatTensor of shape (C x H x W) and
                                                    # normalizes it to the range [0.0, 1.0]
    download=DOWNLOAD_MNIST,                        # download it if you don't have it
)

# plot one example
print(train_data.data.size())      # (60000, 28, 28)
print(train_data.targets.size())   # (60000,)
plt.imshow(train_data.data[2].numpy(), cmap='gray')
plt.title('%i' % train_data.targets[2])
plt.show()

# Data Loader for easy mini-batch return in training; the image batch shape will be (100, 1, 28, 28)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
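
# Optional sanity check (illustrative, not from the original script): one batch
# drawn from the loader should have image shape (100, 1, 28, 28) and label
# shape (100,), matching the comment above.
sample_x, sample_y = next(iter(train_loader))
print(sample_x.size(), sample_y.size())   # torch.Size([100, 1, 28, 28]) torch.Size([100])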

class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 12),
            nn.Tanh(),
            nn.Linear(12, 3),   # compress to 3 features, which can be visualized in a plot
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.Tanh(),
            nn.Linear(12, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28 * 28),
            nn.Sigmoid(),       # squash the output to the range (0, 1), matching the input pixels
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

autoencoder = AutoEncoder()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()
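
# Optional sanity check (illustrative, not from the original script): a dummy
# batch of 4 flattened images should encode to (4, 3) and decode to (4, 784).
dummy = torch.rand(4, 28 * 28)
enc, dec = autoencoder(dummy)
print(enc.size(), dec.size())   # torch.Size([4, 3]) torch.Size([4, 784])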

# initialize figure
f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
plt.ion()   # continuously plot

# original data (first row) for viewing
view_data = train_data.data[:N_TEST_IMG].view(-1, 28 * 28).type(torch.FloatTensor) / 255.
for i in range(N_TEST_IMG):
    a[0][i].imshow(np.reshape(view_data.numpy()[i], (28, 28)), cmap='gray')
    a[0][i].set_xticks(())
    a[0][i].set_yticks(())

for epoch in range(EPOCH):
    for step, (x, y) in enumerate(train_loader):
        b_x = x.view(-1, 28 * 28)        # batch x, shape (batch, 28*28)
        b_y = x.view(-1, 28 * 28)        # batch y, shape (batch, 28*28); the target is the input itself
        b_label = y                      # batch label (not used by the reconstruction loss)

        encoded, decoded = autoencoder(b_x)

        loss = loss_func(decoded, b_y)   # mean squared error
        optimizer.zero_grad()            # clear gradients for this training step
        loss.backward()                  # backpropagation, compute gradients
        optimizer.step()                 # apply gradients

        if step % 100 == 0:
            print('Epoch:', epoch, '| train loss: %.4f' % loss.item())

            # plotting decoded image (second row)
            _, decoded_data = autoencoder(view_data)
            for i in range(N_TEST_IMG):
                a[1][i].clear()
                a[1][i].imshow(np.reshape(decoded_data.data.numpy()[i], (28, 28)), cmap='gray')
                a[1][i].set_xticks(())
                a[1][i].set_yticks(())
            plt.draw()
            plt.pause(0.05)

plt.ioff()
plt.show()

# visualize in 3D plot
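# A minimal sketch of the 3D view announced by the comment above (assumed from
# context, since the script ends here): encode the first 200 training images
# and scatter the 3-D codes, colored by digit label, using the Axes3D/cm
# imports from the top of the file.
view_data_3d = train_data.data[:200].view(-1, 28 * 28).type(torch.FloatTensor) / 255.
with torch.no_grad():
    encoded_3d, _ = autoencoder(view_data_3d)
fig = plt.figure(2)
ax = fig.add_subplot(projection='3d')
X = encoded_3d[:, 0].numpy()
Y = encoded_3d[:, 1].numpy()
Z = encoded_3d[:, 2].numpy()
labels = train_data.targets[:200].numpy()
ax.scatter(X, Y, Z, c=labels, cmap=cm.rainbow)
plt.show()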