import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
import torch.nn.functional as F
import os
import numpy as np
from PIL import Image  # replaces scipy.misc.toimage, which was removed in SciPy 1.2
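# Convolutional VAE on MNIST: the encoder compresses each 28x28 digit into the
# mean and log-variance of a latent Gaussian, the decoder reconstructs the
# image from a sample z, and training minimizes reconstruction error plus the
# KL divergence to a standard-normal prior.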
batch_size = 64
latent_vector = 32  # dimensionality of the latent code z
# The remaining constants are defined but not used by this script:
intermediate_vector = 256
num_class = 10
det = 1e-10
lamb = 2.5
# MNIST batches arrive as (B, 1, 28, 28) tensors in [0, 1] via ToTensor().
train_data = datasets.MNIST(root='./data/', train=True, transform=transforms.ToTensor(), download=True)
test_data = datasets.MNIST(root='./data/', train=False, transform=transforms.ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        # Encoder: 1x28x28 -> 32x28x28 -> 32x14x14 -> 64x14x14 -> 64x7x7
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1, stride=1),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(32, 32, 3, padding=1, stride=2),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(32, 64, 3, padding=1, stride=1),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(64, 64, 3, padding=1, stride=2),  # 7*7*64
            nn.LeakyReLU(negative_slope=0.2),
        )
        # Two linear heads produce the mean and log-variance of q(z|x).
        self.fc_mu = nn.Linear(7*7*64, latent_vector)
        self.fc_logvar = nn.Linear(7*7*64, latent_vector)
        self.fc = nn.Linear(latent_vector, 7*7*64)
        # Decoder mirrors the encoder: 64x7x7 -> 64x14x14 -> 32x28x28 -> 1x28x28
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 64, 4, padding=1, stride=2),
            nn.LeakyReLU(negative_slope=0.2),
            nn.ConvTranspose2d(64, 32, 3, padding=1, stride=1),
            nn.LeakyReLU(negative_slope=0.2),
            nn.ConvTranspose2d(32, 32, 4, padding=1, stride=2),
            nn.LeakyReLU(negative_slope=0.2),
            nn.ConvTranspose2d(32, 1, 3, padding=1, stride=1),
            nn.Sigmoid()  # pixel values in [0, 1], matching ToTensor()
        )
    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps with eps ~ N(0, I); randn_like keeps the noise on
        # the same device/dtype as mu (the Variable wrapper is deprecated).
        eps = torch.randn_like(mu)
        return eps * torch.exp(logvar / 2) + mu

    def forward(self, x):
        h = self.encoder(x)                # one encoder pass instead of two
        h = h.view(h.size(0), -1)          # flatten to (B, 7*7*64)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        z = self.reparameterize(mu, logvar)
        vector = self.fc(z).view(z.size(0), 64, 7, 7)
        return self.decoder(vector), mu, logvar
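# Quick shape check (a hypothetical smoke test, not part of the original script):
#   x = torch.zeros(2, 1, 28, 28)
#   out, mu, logvar = VAE()(x)
#   assert out.shape == (2, 1, 28, 28) and mu.shape == (2, latent_vector)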
model = VAE()

# Summed (rather than averaged) reconstruction error keeps it on the same scale
# as the summed KL term; size_average=False is deprecated in favor of reduction='sum'.
MSE_Loss = nn.MSELoss(reduction='sum')
# Alternative: F.binary_cross_entropy(output, input, reduction='sum')

def loss_function(input, output, mu, logvar):
    Mse = MSE_Loss(output, input)
    # Closed-form KL(q(z|x) || N(0, I)): -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
    KL_loss = 0.5 * torch.sum(-logvar + mu.pow(2) + logvar.exp() - 1)
    return Mse + KL_loss
def save_image(output, size, path, color):
    """Tile a batch of (B, H, W, C) images into a size[0] x size[1] grid and save it."""
    h, w = output.shape[1], output.shape[2]
    if color:
        image = np.zeros((h * size[1], w * size[0], 3))
    else:
        image = np.zeros((h * size[1], w * size[0]))
    for index, data in enumerate(output):
        i = index % size[0]    # column
        j = index // size[0]   # row (the original divided by size[1])
        if color:
            image[h*j : h*j+h, w*i : w*i+w, :] = data
        else:
            # the original sliced h*j : h*j+j here, which truncated the rows
            image[h*j : h*j+h, w*i : w*i+w] = data.squeeze()
    # scipy.misc.toimage no longer exists; Pillow does the same job.
    Image.fromarray((image * 255).astype(np.uint8)).save(path)
def rescale_image(image):
    # Maps roughly [-0.75, 0.75] to [0, 255]; defined but unused below.
    return (image / 1.5 + 0.5) * 255
# SGD with lr=1e-4 converges slowly for a VAE; Adam with lr around 1e-3 is a common alternative.
optimizer = optim.SGD(model.parameters(), lr=0.0001)
def train():
    for epoch in range(1, 10):
        for i, (data, _) in enumerate(train_loader):
            optimizer.zero_grad()
            output, mu, logvar = model(data)  # Variable wrappers are no longer needed
            loss = loss_function(data, output, mu, logvar)
            loss.backward()
            optimizer.step()
            if i % 50 == 0:
                if not os.path.exists("./image"):
                    os.mkdir("./image")
                # (B, 1, 28, 28) -> (B, 28, 28, 1) for the image grid
                np_output = output.detach().numpy().transpose(0, 2, 3, 1)
                save_image(np_output, [8, 8], './image/image_{}.png'.format(i), True)
                print("epoch={} step={} loss={:.2f}".format(epoch, i, loss.item()))

train()
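# A minimal sampling sketch (an addition, not part of the original script):
# once trained, the model generates new digits by decoding z ~ N(0, I) through
# the same fc + decoder layers used in forward(). The function name, grid size,
# and output path are arbitrary choices for this example.
def sample(n=64):
    model.eval()
    with torch.no_grad():
        z = torch.randn(n, latent_vector)                       # prior samples
        samples = model.decoder(model.fc(z).view(n, 64, 7, 7))  # (n, 1, 28, 28)
    np_samples = samples.numpy().transpose(0, 2, 3, 1)
    save_image(np_samples, [8, 8], './image/sample.png', True)

# sample()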