[PyTorch 学习笔记] 8.3 GAN(生成对抗网络)简介

本章代码:

这篇文章主要介绍了生成对抗网络(Generative Adversarial Network),简称 GAN。

GAN 可以看作是一种可以生成特定分布数据的模型。

下面的代码是使用 Generator 来生成人脸图像,Generator 已经训练好保存在 pkl 文件中,只需要加载参数即可。由于模型是在多 GPU 的机器上训练的,因此加载参数后需要使用remove_module()函数来修改state_dict中的key

def remove_module(state_dict_g):
    # remove module.
    from collections import OrderedDict

    new_state_dict = OrderedDict()
    for k, v in state_dict_g.items():
        namekey = k[7:] if k.startswith('module.') else k
        new_state_dict[namekey] = v

    return new_state_dict

把随机的高斯噪声输入到模型中,就可以得到人脸输出,最后进行可视化。全部代码如下:

import os
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
from common_tools import set_seed
from torch.utils.data import DataLoader
from my_dataset import CelebADataset
from dcgan import Discriminator, Generator
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def remove_module(state_dict_g):
    # remove module.
    from collections import OrderedDict

    new_state_dict = OrderedDict()
    for k, v in state_dict_g.items():
        namekey = k[7:] if k.startswith('module.') else k
        new_state_dict[namekey] = v

    return new_state_dict

set_seed(1)  # 设置随机种子

# config
path_checkpoint = os.path.join(BASE_DIR, "gan_checkpoint_14_epoch.pkl")
image_size = 64
num_img = 64
nc = 3
nz = 100
ngf = 128
ndf = 128

d_transforms = transforms.Compose([transforms.Resize(image_size),
                   transforms.CenterCrop(image_size),
                   transforms.ToTensor(),
                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
               ])

# step 1: data
fixed_noise = torch.randn(num_img, nz, 1, 1, device=device)

flag = 0
# flag = 1
if flag:
    z_idx = 0
    single_noise = torch.randn(1, nz, 1, 1, device=device)
    for i in range(num_img):
        add_noise = single_noise
        add_noise = add_noise[0, z_idx, 0, 0] + i*0.01
        fixed_noise[i, ...] = add_noise

# step 2: model
net_g = Generator(nz=nz, ngf=ngf, nc=nc)
# net_d = Discriminator(nc=nc, ndf=ndf)
checkpoint = torch.load(path_checkpoint, map_location="cpu")

state_dict_g = checkpoint["g_model_state_dict"]
state_dict_g = remove_module(state_dict_g)
net_g.load_state_dict(state_dict_g)
net_g.to(device)
# net_d.load_state_dict(checkpoint["d_model_state_dict"])
# net_d.to(device)

# step3: inference
with torch.no_grad():
    fake_data = net_g(fixed_noise).detach().cpu()
img_grid = vutils.make_grid(fake_data, padding=2, normalize=True).numpy()
img_grid = np.transpose(img_grid, (1, 2, 0))
plt.imshow(img_grid)
plt.show()

输出如下:


下面对 GAN 的网络结构进行讲解

Generator 接受随机噪声 $z$ 作为输入,输出生成的数据 $G(z)$。Generator 的目标是让生成数据和真实数据的分布越接近。Discriminator 接收 $G(z)$ 和随机选取的真实数据 $x$,目标是分类真实数据和生成数据,属于 2 分类问题。Discriminator 的目标是把它们二者之间分开。这里体现了对抗的思想,也就是 Generator 要欺骗 Discriminator,而 Discriminator 要识别 Generator。

GAN 的训练和监督学习训练模式的差异

在监督学习的训练模式中,训练数经过模型得到输出值,然后使用损失函数计算输出值与标签之间的差异,根据差异值进行反向传播,更新模型的参数,如下图所示。


在 GAN 的训练模式中,Generator 接收随机数得到输出值,目标是让输出值的分布与训练数据的分布接近,但是这里不是使用人为定义的损失函数来计算输出值与训练数据分布之间的差异,而是使用 Discriminator 来计算这个差异。需要注意的是这个差异不是单个数字上的差异,而是分布上的差异。如下图所示。

# GAN 的训练
  1. 首先固定 Generator,训练 Discriminator。

    • 输入:真实数据 x x x,Generator 生成的数据 G ( z ) G(z) G(z)
    • 输出:二分类概率

    从噪声分布中随机采样噪声 z z z,经过 Generator 生成 G ( z ) G(z) G(z) G ( z ) G(z) G(z) x x x 输入到 Discriminator 得到 D ( x ) D(x) D(x) D ( G ( z ) ) D(G(z)) D(G(z)),损失函数为 1 m ∑ i = 1 m [ log ⁡ D ( x ( i ) ) + log ⁡ ( 1 − D ( G ( z ( i ) ) ) ) ] \frac{1}{m} \sum_{i=1}^{m}\left[\log D\left(\boldsymbol{x}^{(i)}\right)+\log \left(1-D\left(G\left(\boldsymbol{z}^{(i)}\right)\right)\right)\right]

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值