记录一下自己的学习过程
AE(AutoEncoder,自编码器)很早之前就被提出,并且一经提出就被广泛使用,原因是相比大部分网络,它采用的是无监督学习方式。AE的提出不仅仅是为了重建图像,更是为了利用这个网络将图像的特征提取出来,例如添加了噪声的MNIST图片也可以通过AE提取图片的特征,从而恢复图片原本的像素值。这些都是后话了,这篇就单纯讲图像重建。
Fashionmnist数据集
AE
本文使用的数据集是FashionMNIST,框架是PyTorch(个人觉得还是TensorFlow好用一点)。
下面放入代码:
import torch
import torchvision
import os
import torch.optim as optim
import torch.nn as nn
import numpy as np
from torchvision import datasets
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
# Per-sample preprocessing pipeline: only the tensor conversion is active.
# A transforms.Normalize((0.5), (0.5)) step was left disabled by the author.
transform = transforms.Compose([transforms.ToTensor()])

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
def get_dir(image_dir='FashionMNIST_Images'):
    """Create the directory that receives the saved image grids.

    The call is idempotent: a directory that already exists is left untouched.

    Args:
        image_dir: Path of the directory to create. Defaults to
            'FashionMNIST_Images', resolved against the current working
            directory.
    """
    # exist_ok=True replaces the former check-then-create sequence, which
    # could raise FileExistsError if the directory appeared between the
    # existence check and the makedirs call.
    os.makedirs(image_dir, exist_ok=True)
def get_reconstruction_img(img, epoch):
    """Save a batch of flattened reconstructions as one PNG file.

    Args:
        img: Tensor whose first dimension is the batch size and which holds
            28 * 28 values per sample; it is viewed as (batch, 1, 28, 28).
        epoch: Zero-based epoch index; the file name uses ``epoch + 1``.
    """
    batch_size = img.size(0)
    as_images = img.view(batch_size, 1, 28, 28)
    out_path = './FashionMNIST_Images/linear_ae_image{}.png'.format(epoch + 1)
    save_image(as_images, out_path)
# FashionMNIST training split, kept under ./fashionm (download=True).
trainset = datasets.FashionMNIST(root='./fashionm', train=True, download=True, transform=transform)
# Shuffled mini-batches of 128 training samples.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

# FashionMNIST test split, same location and preprocessing as the training split.
testset = datasets.FashionMNIST(root='./fashionm', train=False, download=True, transform=transform)
# Fixed-order mini-batches of 128 test samples.
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False)
class Net(nn.Module):
    """Fully connected autoencoder for flattened 784-value inputs.

    The encoder narrows 784 -> 256 -> 128 -> 64 -> 32 -> 16 and the decoder
    mirrors it back up to 784. Every layer except the last is followed by a
    ReLU; the last layer returns its raw linear output.
    """

    def __init__(self):
        super().__init__()
        # Encoder: compress the input step by step down to a 16-value code.
        self.enc1 = nn.Linear(784, 256)
        self.enc2 = nn.Linear(256, 128)
        self.enc3 = nn.Linear(128, 64)
        self.enc4 = nn.Linear(64, 32)
        self.enc5 = nn.Linear(32, 16)
        # Decoder: expand the 16-value code back to 784 values.
        self.dec1 = nn.Linear(16, 32)
        self.dec2 = nn.Linear(32, 64)
        self.dec3 = nn.Linear(64, 128)
        self.dec4 = nn.Linear(128, 256)
        self.dec5 = nn.Linear(256, 784)

    def forward(self, x):
        """Encode ``x`` to the 16-value code and decode it again.

        Args:
            x: Tensor whose last dimension holds 784 values.

        Returns:
            Tensor of the same shape as ``x`` containing the reconstruction.
        """
        relu_layers = (
            self.enc1, self.enc2, self.enc3, self.enc4, self.enc5,
            self.dec1, self.dec2, self.dec3, self.dec4,
        )
        for layer in relu_layers:
            x = F.relu(layer(x))
        # The output layer is deliberately left without an activation.
        return self.dec5(x)
# Build the autoencoder and place its parameters on the selected device.
net = Net()
net = net.to(device)

# Adam over all model parameters with a learning rate of 1e-3.
optimizer = optim.Adam(net.parameters(), lr=0.001)

# Mean squared error between reconstruction and input is the training loss.
criterion = nn.MSELoss()
def train(epochs=5, log_every=100):
    """Train the module-level ``net`` to reconstruct its own input.

    Uses the module-level ``trainloader``, ``net``, ``optimizer``,
    ``criterion`` and ``device``. After every fifth epoch the reconstruction
    of the last processed batch is written out via
    ``get_reconstruction_img``.

    Args:
        epochs: Number of passes over ``trainloader``. Defaults to 5, the
            value that was previously hard-coded.
        log_every: Number of steps averaged into each logged loss value.
            Defaults to 100, the value that was previously hard-coded.

    Returns:
        list[float]: One mean loss per completed window of ``log_every``
        steps, in training order. Steps after the last full window of an
        epoch are not reported.
    """
    train_loss = []
    for epoch in range(epochs):
        running_loss = 0.0
        output = None  # stays None if the loader yields no batch
        for step, data in enumerate(trainloader):
            # Only the images are needed: the autoencoder's target is its own
            # input, so the labels in data[1] are not moved to the device.
            train_x = data[0].to(device, non_blocking=True)
            train_x = train_x.view(train_x.size(0), -1)

            optimizer.zero_grad()
            output = net(train_x)
            loss = criterion(output, train_x)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if step % log_every == log_every - 1:
                # Separate name so the loss tensor is not rebound to a float.
                mean_loss = running_loss / log_every
                train_loss.append(mean_loss)
                running_loss = 0.0
                print("epoch%d step%d loss%.2f" % (epoch + 1, step + 1, mean_loss))
        # NOTE(review): the original indentation was lost; this check is
        # assumed to run once per epoch, after the step loop - confirm.
        if epoch % 5 == 4 and output is not None:
            # detach() is the supported way to drop the autograd graph.
            get_reconstruction_img(output.detach().cpu(), epoch)
    return train_loss
# Create the image output directory first: train() saves PNG files into it.
get_dir()
# Run training and keep the list of averaged losses for the plot further down.
train_loss=train()
def test_image_reconstruction(net, testloader):
    """Reconstruct the first batch of ``testloader`` and save it as a PNG.

    The images are flattened, passed through ``net`` on the module-level
    ``device``, viewed as (batch, 1, 28, 28) and written to
    './FashionMNIST_Images/fashionmnist_reconstruction1.png'.

    Args:
        net: Model that maps a (batch, 784) tensor to a (batch, 784) tensor.
        testloader: Iterable yielding ``(images, labels)`` batches; only the
            first batch is used. Nothing is saved if it yields no batch.
    """
    first_batch = next(iter(testloader), None)
    if first_batch is None:
        return

    img, _ = first_batch
    img = img.to(device)
    img = img.view(img.size(0), -1)

    # Inference only: no_grad avoids building an autograd graph, so the
    # result needs no .data / detach() before it is saved.
    with torch.no_grad():
        outputs = net(img)
    outputs = outputs.view(outputs.size(0), 1, 28, 28).cpu()

    # Only the first batch is saved, hence the fixed index 1 in the file name.
    save_image(outputs, './FashionMNIST_Images/fashionmnist_reconstruction{}.png'.format(1))
# Save the reconstruction of the first test batch using the trained network.
test_image_reconstruction(net, testloader)
# Plot the averaged losses returned by train() as a red line; each point is
# the mean over one window of 100 training steps.
plt.figure()
plt.plot(train_loss,"r-")
plt.title("Train_loss")
plt.ylabel("loss")
plt.show()
在有些电脑上,上面直接创建device的那一行可能无法正常运行,这时也可以改为创建一个函数来选择设备:
def get_device():
    """Return the device name to use.

    Returns:
        str: 'cuda:0' when ``torch.cuda.is_available()`` is true, else 'cpu'.
    """
    return 'cuda:0' if torch.cuda.is_available() else 'cpu'