一、GAN模型介绍
GAN:Generative adversarial network生成对抗网络。
GAN框架让一个深度学习模型学习训练数据分布,从而生成具有同分布的类似数据。
GAN由两个不同的模型组成,一个是生成模型G(Generator),一个是鉴别模型D(Discriminator)。其中,G的作用是产生fake图像使其的分布与训练图像相似; D的作用是来判断这个fake图像与真正的图像是否相同。
训练过程中,G通过产生越来越逼真的fake图像,不断试图骗过D;同时D也在不断提升区分真假图像的能力。当生成器生成的赝品看起来像是直接来自训练数据、判别器对任意输入判断真假的概率都稳定在50%时,训练达到平衡。
此次实验采用DCGAN作为模型架构。DCGAN是将CNN与GAN的一种结合,将GAN的G和D换成了两个CNN。
二、训练过程
utils.py
from torch.utils.data import Dataset, DataLoader
import cv2
import os
import glob
import torchvision.transforms as transforms
import random
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
class FaceDataset(Dataset):
    """Dataset of face images loaded from a list of file paths.

    Each item is a single transformed image tensor; there are no labels.
    """

    def __init__(self, fnames, transform):
        self.fnames = fnames
        self.transform = transform
        self.num_samples = len(fnames)

    def __getitem__(self, idx):
        # Load as BGR (OpenCV default), convert to RGB, then transform.
        image = cv2.imread(self.fnames[idx])
        return self.transform(self.BGR2RGB(image))

    def __len__(self):
        return self.num_samples

    def BGR2RGB(self, img):
        # OpenCV reads images in BGR channel order; torchvision expects RGB.
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def get_dataset(root):
    """Build a FaceDataset over every file directly under `root`.

    Images are resized to 64x64 and normalized to [-1, 1] per channel,
    matching the Tanh output range of the generator.
    """
    fnames = glob.glob(os.path.join(root, '*'))
    pipeline = [
        transforms.ToPILImage(),
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
    return FaceDataset(fnames, transforms.Compose(pipeline))
def same_seeds(seed):
    """Seed every RNG used in training so that runs are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for deterministic kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def weights_init(m):
    """DCGAN weight initialization, meant to be applied via Module.apply().

    Conv/ConvTranspose layers get weights ~ N(0, 0.02); BatchNorm layers
    get weights ~ N(1, 0.02) and zero bias, as in the DCGAN paper.
    """
    layer_name = m.__class__.__name__
    if 'Conv' in layer_name:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in layer_name:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
class Generator(nn.Module):
    """DCGAN generator.

    Maps a latent batch of shape (N, in_dim) to an RGB image batch of
    shape (N, 3, 64, 64) with values in [-1, 1] (Tanh output).
    """

    def __init__(self, in_dim, dim=64):
        super(Generator, self).__init__()

        def upsample_block(c_in, c_out):
            # Stride-2 transposed conv doubles the spatial resolution.
            return nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out, 5, 2,
                                   padding=2, output_padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU())

        # Project the latent vector to a (dim*8) x 4 x 4 feature map.
        self.l1 = nn.Sequential(
            nn.Linear(in_dim, dim * 8 * 4 * 4, bias=False),
            nn.BatchNorm1d(dim * 8 * 4 * 4),
            nn.ReLU())
        # 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64
        self.l2_5 = nn.Sequential(
            upsample_block(dim * 8, dim * 4),
            upsample_block(dim * 4, dim * 2),
            upsample_block(dim * 2, dim),
            nn.ConvTranspose2d(dim, 3, 5, 2, padding=2, output_padding=1),
            nn.Tanh())
        self.apply(weights_init)

    def forward(self, x):
        out = self.l1(x)
        out = out.view(out.size(0), -1, 4, 4)
        return self.l2_5(out)
class Discriminator(nn.Module):
    """DCGAN discriminator.

    Maps an image batch of shape (N, 3, 64, 64) to a flat tensor of
    shape (N,) of real/fake probabilities (Sigmoid output).
    """

    def __init__(self, in_dim, dim=64):
        super(Discriminator, self).__init__()

        def downsample_block(c_in, c_out):
            # Stride-2 conv halves the spatial resolution.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 5, 2, 2),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2))

        # 64x64 -> 32 -> 16 -> 8 -> 4, then a 4x4 conv down to one logit.
        # No BatchNorm on the first layer, following the DCGAN paper.
        self.ls = nn.Sequential(
            nn.Conv2d(in_dim, dim, 5, 2, 2), nn.LeakyReLU(0.2),
            downsample_block(dim, dim * 2),
            downsample_block(dim * 2, dim * 4),
            downsample_block(dim * 4, dim * 8),
            nn.Conv2d(dim * 8, 1, 4),
            nn.Sigmoid())
        self.apply(weights_init)

    def forward(self, x):
        return self.ls(x).view(-1)
hw11.py
import torch
from torch import optim
from torch.autograd import Variable
import torchvision
from utils import *
import matplotlib.pyplot as plt
if __name__ == '__main__':
    # Hyperparameters
    batch_size = 64
    z_dim = 100          # latent vector dimension
    lr = 1e-4
    n_epoch = 10
    save_dir = 'logs'
    os.makedirs(save_dir, exist_ok=True)

    # Build the models
    G = Generator(in_dim=z_dim).cuda()
    D = Discriminator(3).cuda()
    G.train()
    D.train()

    # Binary cross-entropy on the discriminator's Sigmoid output
    criterion = nn.BCELoss()

    # DCGAN-recommended Adam betas (0.5, 0.999)
    opt_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))

    same_seeds(0)

    # Load data
    dataset = get_dataset('faces')
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    # Fixed noise, reused every epoch to visualize training progress.
    # (Variable is deprecated; plain tensors carry autograd since PyTorch 0.4.)
    z_sample = torch.randn(100, z_dim).cuda()

    for epoch in range(n_epoch):
        for i, data in enumerate(dataloader):
            imgs = data.cuda()
            bs = imgs.size(0)

            """ Train D """
            z = torch.randn(bs, z_dim).cuda()
            r_imgs = imgs
            f_imgs = G(z)

            # Labels: 1 for real images, 0 for generated ones
            r_label = torch.ones(bs).cuda()
            f_label = torch.zeros(bs).cuda()

            # detach() the fakes so D's loss does not backprop into G
            r_logit = D(r_imgs)
            f_logit = D(f_imgs.detach())

            r_loss = criterion(r_logit, r_label)
            f_loss = criterion(f_logit, f_label)
            loss_D = (r_loss + f_loss) / 2

            D.zero_grad()
            loss_D.backward()
            opt_D.step()

            """ Train G """
            z = torch.randn(bs, z_dim).cuda()
            f_imgs = G(z)
            f_logit = D(f_imgs)
            # G tries to make D classify its output as real
            loss_G = criterion(f_logit, r_label)

            G.zero_grad()
            loss_G.backward()
            opt_G.step()

            # Progress line, overwritten in place with '\r'
            print(
                f'\rEpoch [{epoch + 1}/{n_epoch}] {i + 1}/{len(dataloader)} Loss_D: {loss_D.item():.4f} Loss_G: {loss_G.item():.4f}',
                end='')

        # Sample from the fixed noise and save a grid of generated images
        G.eval()
        with torch.no_grad():
            # Map Tanh output from [-1, 1] to [0, 1] for saving/display
            f_imgs_sample = (G(z_sample) + 1) / 2.0
        filename = os.path.join(save_dir, f'Epoch_{epoch + 1:03d}.jpg')
        torchvision.utils.save_image(f_imgs_sample, filename, nrow=10)
        print(f' | Save some samples to {filename}.')

        # Show the generated images
        grid_img = torchvision.utils.make_grid(f_imgs_sample.cpu(), nrow=10)
        plt.figure(figsize=(10, 10))
        plt.imshow(grid_img.permute(1, 2, 0))
        plt.show()
        G.train()

        # Checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0:
            torch.save(G.state_dict(), 'dcgan_g.pth')
            torch.save(D.state_dict(), 'dcgan_d.pth')
epoch | 训练中间过程中generator的结果 |
---|---|
1 | ![]() |
2 | ![]() |
3 | ![]() |
4 | ![]() |
5 | ![]() |
6 | ![]() |
7 | ![]() |
8 | ![]() |
9 | ![]() |
10 | ![]() |
如理论分析得到的结果所示,图片确实是在不断变得清晰以及更加靠近数据集内真实图片的。
三、实验结果
generate.py
import torch
from torch import optim
from torch.autograd import Variable
import torchvision
import matplotlib.pyplot as plt
from utils import *
if __name__ == '__main__':
    z_dim = 100

    # Load the trained generator weights
    G = Generator(z_dim)
    G.load_state_dict(torch.load('dcgan_g.pth'))
    G.eval()
    G.cuda()

    # Sample latent vectors and generate images.
    # (Variable is deprecated; plain tensors suffice, and no_grad skips
    # building an autograd graph during inference.)
    n_output = 20
    z_sample = torch.randn(n_output, z_dim).cuda()
    with torch.no_grad():
        # Map Tanh output from [-1, 1] to [0, 1] for saving/display
        imgs_sample = (G(z_sample) + 1) / 2.0
    save_dir = 'logs'
    os.makedirs(save_dir, exist_ok=True)  # ensure output dir exists
    filename = os.path.join(save_dir, 'result.jpg')
    torchvision.utils.save_image(imgs_sample, filename, nrow=10)

    # Show the generated images
    grid_img = torchvision.utils.make_grid(imgs_sample.cpu(), nrow=10)
    plt.figure(figsize=(10, 10))
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.show()