用pytorch实现auto-encoder

记录一下自己的学习过程

AE很久之前就被提出,一经提出就被广泛使用,原因是比较大部分的网络,它采用的是无监督学习方式。AE的提出不仅仅是为了重建图像,而是为了利用这个网络将图像的特征提取出来,例如添加了噪声的mnist也可以通过AE提取图片的特征从而恢复图片的像素值。这都是后话了,这篇就单纯讲图像重建,

                                                                Fashionmnist数据集

                                                

                                                                         AE

本文用的数据集是fashionmnist数据集,框架是pytorch(个人觉得还是tensorflow好用点)

下面放入代码:

import torch
import torchvision
import os
import torch.optim as optim
import torch.nn as nn
import numpy as np
from torchvision import datasets
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt

transform=transforms.Compose([
    transforms.ToTensor(),
    #transforms.Normalize((0.5),(0.5))
    ])

device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def get_dir():
    image_dir = 'FashionMNIST_Images'
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)

def get_reconstruction_img(img,epoch):
    img=img.view(img.size(0),1,28,28)
    save_image(img, './FashionMNIST_Images/linear_ae_image{}.png'.format(epoch+1))

trainset=datasets.FashionMNIST(
    root='./fashionm',
    train=True,
    download=True,
    transform=transform
    
    )
trainloader=torch.utils.data.DataLoader(
    trainset,
    batch_size=128,
    shuffle=True
    
    
    )
testset=datasets.FashionMNIST(
    root='./fashionm',
    train=False,
    download=True,
    transform=transform
    
    )
testloader=torch.utils.data.DataLoader(
    testset,
    batch_size=128,
    shuffle=False
    
    
    )

class Net(nn.Module):
    def __init__(self):
        # encoder
        super(Net,self).__init__()
        self.enc1 = nn.Linear(in_features=784, out_features=256)
        self.enc2 = nn.Linear(in_features=256, out_features=128)
        self.enc3 = nn.Linear(in_features=128, out_features=64)
        self.enc4 = nn.Linear(in_features=64, out_features=32)
        self.enc5 = nn.Linear(in_features=32, out_features=16)
        # decoder 
        self.dec1 = nn.Linear(in_features=16, out_features=32)
        self.dec2 = nn.Linear(in_features=32, out_features=64)
        self.dec3 = nn.Linear(in_features=64, out_features=128)
        self.dec4 = nn.Linear(in_features=128, out_features=256)
        self.dec5 = nn.Linear(in_features=256, out_features=784)
    def forward(self, x):
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc2(x))
        x = F.relu(self.enc3(x))
        x = F.relu(self.enc4(x))
        x = F.relu(self.enc5(x))
        x = F.relu(self.dec1(x))
        x = F.relu(self.dec2(x))
        x = F.relu(self.dec3(x))
        x = F.relu(self.dec4(x))
        x = self.dec5(x)
        return x
    
net=Net().to(device)
    
optimizer=optim.Adam(net.parameters(),lr=1e-3)
criterion=nn.MSELoss()
    
def train():
    train_loss=[]
    for epoch in range(5):
        running_loss=0
        for step,data in enumerate(trainloader):
             train_x,target=data[0].to(device,non_blocking=True),data[1].to(device,non_blocking=True)
             train_x=train_x.view(train_x.size(0),-1)
             optimizer.zero_grad()
             output=net(train_x)
             loss=criterion(output,train_x)
             loss.backward()
             optimizer.step()
             running_loss+=loss.item()
             if step%100==99:
                 loss=running_loss/100
                 train_loss.append(loss)
                 running_loss=0
                 print("epoch%d  step%d  loss%.2f"%(epoch+1,step+1,loss))
                 
        if epoch%5==4:
             get_reconstruction_img(output.cpu().data,epoch)
    return train_loss
get_dir()                 
train_loss=train() 
def test_image_reconstruction(net, testloader):
     for step,batch in enumerate(testloader):
        img, _ = batch
        img = img.to(device)
        img = img.view(img.size(0), -1)
        outputs = net(img)
        outputs = outputs.view(outputs.size(0), 1, 28, 28).cpu().data
        
        save_image(outputs, './FashionMNIST_Images/fashionmnist_reconstruction{}.png'.format(step+1))
        break
test_image_reconstruction(net, testloader)

plt.figure()
plt.plot(train_loss,"r-")
plt.title("Train_loss")
plt.ylabel("loss")
plt.show()

有些电脑对于device这一块可能不能正常运行,也可以创建一个函数:

def get_device():
    if torch.cuda.is_available():
        device = 'cuda:0'
    else:
        device = 'cpu'
    return device

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值