# Python autoencoder: a deep-learning autoencoder example (python自编码器_深度学习之自编码器 示例)

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""Created on Mon Jan 1 12:45:57 2018

@author: pc"""

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
from torch.autograd import Variable

# torch.manual_seed(1)  # reproducible

# Hyper Parameters
EPOCH = 10             # number of full passes over the training set
BATCH_SIZE = 100       # images per mini-batch fed to the DataLoader
LR = 0.005             # learning rate
DOWNLOAD_MNIST = True  # let torchvision download MNIST if it is not on disk yet

N_TEST_IMG = 5         # how many sample digits to show in the original/decoded figure

# MNIST handwritten-digit dataset (training split).
# ToTensor() turns each PIL.Image / numpy.ndarray into a torch.FloatTensor of
# shape (C x H x W) with values normalized into the range [0.0, 1.0].
train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    download=DOWNLOAD_MNIST,  # fetch the files if they are not already present
    train=True,  # select the training portion of MNIST
    transform=torchvision.transforms.ToTensor(),
)

# plot one example

# Sanity check: print the dataset sizes and display a single training image.
# `.data` / `.targets` replace torchvision's deprecated `.train_data` /
# `.train_labels` aliases, which only emit warnings and forward to these.
print(train_data.data.size())     # (60000, 28, 28)
print(train_data.targets.size())  # (60000)
plt.imshow(train_data.data[2].numpy(), cmap='gray')
plt.title('%i' % train_data.targets[2].item())  # .item() -> plain Python int for %i
plt.show()

# Data Loader for easy mini-batch return in training, the image batch shape will be (100, 1, 28, 28)

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)


class AutoEncoder(nn.Module):
    """Fully connected autoencoder for flattened 28x28 images.

    The encoder compresses a 28*28-feature input down to 3 features (few
    enough to scatter-plot in 3-D); the decoder expands those 3 features back
    to 28*28 values, squashed into (0, 1) by a final Sigmoid.
    """

    def __init__(self):
        super().__init__()

        # 784 -> 128 -> 64 -> 12 -> 3, with Tanh between the linear layers.
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 12),
            nn.Tanh(),
            nn.Linear(12, 3),  # compress to 3 features which can be visualized in plt
        )

        # Mirror of the encoder: 3 -> 12 -> 64 -> 128 -> 784.
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.Tanh(),
            nn.Linear(12, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28 * 28),
            nn.Sigmoid(),  # compress to a range (0, 1)
        )

    def forward(self, x):
        """Encode and then decode ``x``.

        Args:
            x: tensor whose last dimension has 28*28 features
                (e.g. a batch of flattened images, shape (batch, 28*28)).

        Returns:
            A tuple ``(encoded, decoded)``: the 3-feature code and the
            28*28-feature reconstruction.
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

# Model, optimizer and objective: Adam over all of the autoencoder's
# parameters, trained to minimise the mean squared reconstruction error.
autoencoder = AutoEncoder()
optimizer = torch.optim.Adam(params=autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()

# initialize figure

f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
plt.ion()  # continuously plot

# original data (first row) for viewing.
# `.data` holds the raw pixel values, so divide by 255 to match the [0, 1]
# range that ToTensor() gives the training batches.
view_data = train_data.data[:N_TEST_IMG].view(-1, 28 * 28).float() / 255.
for i in range(N_TEST_IMG):
    a[0][i].imshow(np.reshape(view_data.numpy()[i], (28, 28)), cmap='gray')
    a[0][i].set_xticks(())
    a[0][i].set_yticks(())

for epoch in range(EPOCH):
    for step, (x, _) in enumerate(train_loader):  # labels are not needed
        b_x = x.view(-1, 28 * 28)  # batch x, shape (batch, 28*28)
        # The target is the input itself: the autoencoder learns to reproduce it.
        b_y = x.view(-1, 28 * 28)  # batch y, shape (batch, 28*28)

        _, decoded = autoencoder(b_x)

        loss = loss_func(decoded, b_y)  # mean square error
        optimizer.zero_grad()           # clear gradients for this training step
        loss.backward()                 # backpropagation, compute gradients
        optimizer.step()                # apply gradients

        if step % 100 == 0:
            # loss is a 0-dim tensor: loss.data[0] raises IndexError, use .item().
            print('Epoch:', epoch, '| train loss: %.4f' % loss.item())

            # plotting decoded image (second row).
            # NOTE(review): placement inside this `if` is reconstructed from the
            # collapsed original; it refreshes the preview every 100 steps.
            with torch.no_grad():  # preview only, no autograd graph needed
                _, decoded_data = autoencoder(view_data)
            for i in range(N_TEST_IMG):
                a[1][i].clear()
                a[1][i].imshow(np.reshape(decoded_data.numpy()[i], (28, 28)), cmap='gray')
                a[1][i].set_xticks(())
                a[1][i].set_yticks(())
            plt.draw()
            plt.pause(0.05)

plt.ioff()
plt.show()

# visualize in 3D plot
# NOTE(review): the 3-D visualization code itself is not included here.

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值