本章代码:
这篇文章主要介绍了生成对抗网络(Generative Adversarial Network),简称 GAN。
GAN 可以看作是一种可以生成特定分布数据的模型。
下面的代码是使用 Generator 来生成人脸图像,Generator 已经训练好保存在 pkl 文件中,只需要加载参数即可。由于模型是在多 GPU 的机器上训练的,因此加载参数后需要使用remove_module()
函数来修改state_dict
中的key
。
def remove_module(state_dict_g):
# remove module.
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict_g.items():
namekey = k[7:] if k.startswith('module.') else k
new_state_dict[namekey] = v
return new_state_dict
把随机的高斯噪声输入到模型中,就可以得到人脸输出,最后进行可视化。全部代码如下:
import os
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
from common_tools import set_seed
from torch.utils.data import DataLoader
from my_dataset import CelebADataset
from dcgan import Discriminator, Generator
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def remove_module(state_dict_g):
# remove module.
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict_g.items():
namekey = k[7:] if k.startswith('module.') else k
new_state_dict[namekey] = v
return new_state_dict
set_seed(1) # 设置随机种子
# config
path_checkpoint = os.path.join(BASE_DIR, "gan_checkpoint_14_epoch.pkl")
image_size = 64
num_img = 64
nc = 3
nz = 100
ngf = 128
ndf = 128
d_transforms = transforms.Compose([transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# step 1: data
fixed_noise = torch.randn(num_img, nz, 1, 1, device=device)
flag = 0
# flag = 1
if flag:
z_idx = 0
single_noise = torch.randn(1, nz, 1, 1, device=device)
for i in range(num_img):
add_noise = single_noise
add_noise = add_noise[0, z_idx, 0, 0] + i*0.01
fixed_noise[i, ...] = add_noise
# step 2: model
net_g = Generator(nz=nz, ngf=ngf, nc=nc)
# net_d = Discriminator(nc=nc, ndf=ndf)
checkpoint = torch.load(path_checkpoint, map_location="cpu")
state_dict_g = checkpoint["g_model_state_dict"]
state_dict_g = remove_module(state_dict_g)
net_g.load_state_dict(state_dict_g)
net_g.to(device)
# net_d.load_state_dict(checkpoint["d_model_state_dict"])
# net_d.to(device)
# step3: inference
with torch.no_grad():
fake_data = net_g(fixed_noise).detach().cpu()
img_grid = vutils.make_grid(fake_data, padding=2, normalize=True).numpy()
img_grid = np.transpose(img_grid, (1, 2, 0))
plt.imshow(img_grid)
plt.show()
输出如下:
下面对 GAN 的网络结构进行讲解
Generator 接受随机噪声 $z$ 作为输入,输出生成的数据 $G(z)$。Generator 的目标是让生成数据和真实数据的分布越接近。Discriminator 接收 $G(z)$ 和随机选取的真实数据 $x$,目标是分类真实数据和生成数据,属于 2 分类问题。Discriminator 的目标是把它们二者之间分开。这里体现了对抗的思想,也就是 Generator 要欺骗 Discriminator,而 Discriminator 要识别 Generator。
GAN 的训练和监督学习训练模式的差异
在监督学习的训练模式中,训练数经过模型得到输出值,然后使用损失函数计算输出值与标签之间的差异,根据差异值进行反向传播,更新模型的参数,如下图所示。
在 GAN 的训练模式中,Generator 接收随机数得到输出值,目标是让输出值的分布与训练数据的分布接近,但是这里不是使用人为定义的损失函数来计算输出值与训练数据分布之间的差异,而是使用 Discriminator 来计算这个差异。需要注意的是这个差异不是单个数字上的差异,而是分布上的差异。如下图所示。
# GAN 的训练
-
首先固定 Generator,训练 Discriminator。
- 输入:真实数据 x x x,Generator 生成的数据 G ( z ) G(z) G(z)
- 输出:二分类概率
从噪声分布中随机采样噪声 z z z,经过 Generator 生成 G ( z ) G(z) G(z)。 G ( z ) G(z) G(z) 和 x x x 输入到 Discriminator 得到 D ( x ) D(x) D(x) 和 D ( G ( z ) ) D(G(z)) D(G(z)),损失函数为 1 m ∑ i = 1 m [ log D ( x ( i ) ) + log ( 1 − D ( G ( z ( i ) ) ) ) ] \frac{1}{m} \sum_{i=1}^{m}\left[\log D\left(\boldsymbol{x}^{(i)}\right)+\log \left(1-D\left(G\left(\boldsymbol{z}^{(i)}\right)\right)\right)\right]