网络全部使用全连接层 nn.Linear() 实现。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import os
import math
class VAE(nn.Module):
    """Variational autoencoder built entirely from fully connected layers.

    Encoder: flattened input -> hidden -> (mu, log-variance) of a diagonal
    Gaussian over the latent code.  Decoder: latent z -> hidden -> sigmoid
    output in [0, 1], same size as the flattened input.
    """

    def __init__(self, input_dim=28 * 28, hidden_dim=256, latent_dim=8):
        """Build the encoder/decoder layers.

        Args:
            input_dim: flattened input size (default 28*28 for MNIST).
            hidden_dim: width of the single hidden layer on each side.
            latent_dim: dimensionality of the latent code z.
        """
        super(VAE, self).__init__()
        self.input_dim = input_dim
        self.E_fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)
        self.z_fc = nn.Linear(latent_dim, hidden_dim)
        self.D_fc1 = nn.Linear(hidden_dim, input_dim)

    def repara(self, mu, logvar):
        """Reparameterization trick: z = mu + eps * sigma, eps ~ N(0, I).

        ``logvar`` is log(sigma^2); exp(0.5 * logvar) recovers sigma, which
        keeps the sampling step differentiable w.r.t. mu and logvar.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decoder(self, z):
        """Map a latent code z to a flattened reconstruction in [0, 1]."""
        out = F.relu(self.z_fc(z))
        return torch.sigmoid(self.D_fc1(out))

    def forward(self, x):
        """Encode x, sample z, and decode.

        Returns:
            (reconstruction, mu, logvar) — reconstruction is flattened to
            (batch, input_dim); mu/logvar parameterize the latent Gaussian.
        """
        h = F.relu(self.E_fc1(x.view(-1, self.input_dim)))
        mu = self.fc_mu(h)
        logvar = self.fc_var(h)
        z = self.repara(mu, logvar)
        return self.decoder(z), mu, logvar
if __name__ == '__main__':
    num_epochs = 100
    batch_size = 64
    # Run on GPU when available; the original hard-coded .cuda() and
    # crashed on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = VAE().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    train_dataset = torchvision.datasets.MNIST(root='./data',
                                               train=True,
                                               transform=transforms.ToTensor(),
                                               download=True)
    test_dataset = torchvision.datasets.MNIST(root='./data',
                                              train=False,
                                              transform=transforms.ToTensor(),
                                              download=True)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False)

    # Create the output directory once, up front (original re-checked it
    # inside the training loop every 500 steps).
    out_dir = './vae_val_result'
    os.makedirs(out_dir, exist_ok=True)

    step = 0  # renamed from 'iter', which shadowed the builtin
    model.train()
    for epoch in range(num_epochs):
        for img, _ in train_loader:
            img = img.to(device)
            out, mu, logvar = model(img)
            # Summed Bernoulli reconstruction loss + KL(q(z|x) || N(0, I)).
            rec_loss = F.binary_cross_entropy(out, img.view(-1, 28 * 28),
                                              reduction='sum')
            kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            loss = rec_loss + kl_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 10 == 0:
                # .item() extracts Python floats for clean formatting.
                print("epoch {} iter {}: kl_loss={:.4f}, rec_loss={:.4f}, loss={:.4f}".format(
                    epoch, step, kl_loss.item(), rec_loss.item(), loss.item()))

            if step % 500 == 0:
                # Dump reconstructions (top row) vs. ground truth (bottom row).
                # Only the first test batch is needed — the original iterated
                # the whole test set and used just batch 0.
                model.eval()
                with torch.no_grad():
                    test_img, _ = next(iter(test_loader))
                    test_img = test_img.to(device)
                    test_out, _, _ = model(test_img)
                    save_data = torch.cat([test_out.view(-1, 1, 28, 28)[9:17],
                                           test_img.view(-1, 1, 28, 28)[9:17]])
                    torchvision.utils.save_image(
                        save_data,
                        out_dir + '/epoch' + str(epoch) + 'iter' + str(step) + '.jpg')
                # Bug fix: the original never returned to train mode after
                # evaluation, so all later training ran in eval mode.
                model.train()
            step += 1
    print("Training is over!")

    # Sample 64 images from the prior N(0, I) and save a grid.
    model.eval()
    with torch.no_grad():
        z = torch.randn(64, 8, device=device)
        sample_data = model.decoder(z)
        # Fixed the stray space the original had in the filename.
        torchvision.utils.save_image(sample_data.view(-1, 1, 28, 28),
                                     out_dir + '/sample_img.jpg')
    print("Saving successfully sample_image!")
训练 100 个 epoch 后的测试结果(上面一行为重构结果,下面一行为 ground truth):
采样结果:
从结果可以看出,全连接网络的 VAE 重构效果已经比较好,但采样结果一般;若改用卷积网络,生成质量应该会有明显提升。