# 我找到了自己改的代码,还加了几层网络,效果确实是比只有一层网络的好
# (当然,这是在学长学姐们的帮助下修改完成的,尤其感谢文哥帮我调试并修改错误)
# (Translation: I found my modified code and added a few more network layers;
# the result is indeed better than with a single layer. Done with help from
# senior students — special thanks to Wen-ge for debugging and fixing errors.)
import torch
from torch.utils.data import Dataset
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import os
import numpy as np
# from skimage import io
from PIL import Image
from time import time
#from test_ae import *
# Training hyperparameters.
num_epochs = 500
batch_size = 128
hidden_size = 10  # dimensionality of the autoencoder bottleneck (h_dim)
# Earlier experiment on MNIST, kept for reference.
'''
# MNIST dataset
dataset = dsets.MNIST(root='../MINIST/MNIST',
train=True,
transform=transforms.ToTensor(),
download=True)
'''
class AnimalData(Dataset):
    """Dataset yielding every file in a directory as a grayscale PIL image.

    The returned label is the constant 1 -- the autoencoder training loop
    discards it and reconstructs the image itself.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir                  # directory holding the image files
        self.transform = transform                # optional preprocessing pipeline
        self.images = os.listdir(self.root_dir)   # one entry per file in the directory

    def __len__(self):
        # Size of the dataset = number of files found at construction time.
        return len(self.images)

    def __getitem__(self, index):
        # Resolve the file for this index and load it as a single-channel image.
        file_name = self.images[index]
        full_path = os.path.join(self.root_dir, file_name)
        picture = Image.open(full_path).convert('L')  # 'L' forces grayscale
        if self.transform:
            picture = self.transform(picture)
        return picture, 1  # dummy label; reconstruction target is the image itself
# Directory of training images; every file in it is treated as one sample.
# NOTE(review): machine-specific absolute path — must exist at runtime.
dir_path = '/mnt/sdc/cxdu/cat_pic/tesst'
# Pipeline: resize shorter side to 256 -> center-crop 224x224 -> grayscale
# (1 channel) -> float tensor in [0, 1] (required by BCELoss targets).
transform = transforms.Compose([transforms.Resize(size = 256), transforms.CenterCrop(size = 224), transforms.Grayscale(num_output_channels=1), transforms.ToTensor()])
dataset = AnimalData(dir_path, transform)
# Data loader
data_loader = torch.utils.data.DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=True)
def to_var(x):
    """Move *x* to the GPU when one is available, then wrap it in a Variable."""
    return Variable(x.cuda() if torch.cuda.is_available() else x)
class Autoencoder(nn.Module):
    """Fully-connected autoencoder with a 4-layer encoder and 4-layer decoder.

    Encoder: in_dim -> n_hidden_1 -> n_hidden_2 -> n_hidden_3 -> h_dim (ReLU).
    Decoder mirrors it back to in_dim, with Sigmoid after every layer so the
    output lies in [0, 1] (matching BCELoss targets).
    """

    def __init__(self, in_dim, n_hidden_1, n_hidden_2, n_hidden_3, h_dim):
        super(Autoencoder, self).__init__()
        # Encoder stages (ReLU activations).
        self.enlayer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.ReLU(True))
        self.enlayer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.ReLU(True))
        self.enlayer3 = nn.Sequential(nn.Linear(n_hidden_2, n_hidden_3), nn.ReLU(True))
        self.enlayer4 = nn.Sequential(nn.Linear(n_hidden_3, h_dim), nn.ReLU(True))
        # Decoder stages (Sigmoid activations).
        self.delayer1 = nn.Sequential(nn.Linear(h_dim, n_hidden_3), nn.Sigmoid())
        self.delayer2 = nn.Sequential(nn.Linear(n_hidden_3, n_hidden_2), nn.Sigmoid())
        self.delayer3 = nn.Sequential(nn.Linear(n_hidden_2, n_hidden_1), nn.Sigmoid())
        self.delayer4 = nn.Sequential(nn.Linear(n_hidden_1, in_dim), nn.Sigmoid())

    def forward(self, x):
        # Pass the input through every stage in order: encode, then decode.
        out = x
        for stage in (self.enlayer1, self.enlayer2, self.enlayer3, self.enlayer4,
                      self.delayer1, self.delayer2, self.delayer3, self.delayer4):
            out = stage(out)
        return out
# Build the autoencoder for 224x224 grayscale inputs (224 * 224 = 50176 pixels).
ae = Autoencoder(in_dim=50176, n_hidden_1=5000, n_hidden_2=500, n_hidden_3=50, h_dim=hidden_size)
if torch.cuda.is_available():
    ae.cuda()
# BCELoss measures reconstruction error; targets must lie in [0, 1], which
# holds because ToTensor() scales pixels into that range.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(ae.parameters(), lr=0.001)
iter_per_epoch = len(data_loader)
data_iter = iter(data_loader)

# Keep one fixed batch so reconstructions are comparable across epochs.
fixed_x, _ = next(data_iter)
x1, y1 = fixed_x.shape[2], fixed_x.shape[3]  # spatial size (H, W) for un-flattening later
os.makedirs('./result', exist_ok=True)  # save_image raises FileNotFoundError if the dir is missing
torchvision.utils.save_image(Variable(fixed_x).data.cpu(), './result/real_images.png')
fixed_x = to_var(fixed_x.view(fixed_x.size(0), -1))
# Training loop: flatten each batch, reconstruct it, and minimize BCE.
for epoch in range(num_epochs):
    t0 = time()
    for i, (images, _) in enumerate(data_loader):
        # Flatten each (1, H, W) image into a vector for the fully-connected AE.
        images = to_var(images.view(images.size(0), -1))
        out = ae(images)
        loss = criterion(out, images)  # reconstruction target is the input itself
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i + 1) % 100 == 0:
            # len(dataset) // batch_size approximates iterations per epoch.
            print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f Time: %.2fs'
                  % (epoch + 1, num_epochs, i + 1, len(dataset) // batch_size,
                     loss.item(), time() - t0))
    # Save reconstructions of the fixed debug batch once per epoch.
    os.makedirs('./data', exist_ok=True)  # save_image raises FileNotFoundError if the dir is missing
    reconst_images = ae(fixed_x)
    reconst_images = reconst_images.view(reconst_images.size(0), 1, x1, y1)
    torchvision.utils.save_image(reconst_images.data.cpu(), './data/reconst_images_%d.png' % (epoch + 1))