一、什么是 GAN?
GAN,全称 Generative Adversarial Network(生成对抗网络),是一种能够生成逼真数据的深度学习模型。它由 Ian Goodfellow 在 2014 年提出,主要用于生成高质量的图像、视频、音频等数据。
GAN 的核心思想是让两个神经网络相互对抗,一个负责生成数据,另一个负责判断数据真伪,最终逼近真实数据的分布。
二、GAN的基本结构
GAN 由两个部分组成:
- 生成器(Generator, G):负责创造新的数据,例如生成一张看起来很真实的图片。
- 判别器(Discriminator, D):负责区分输入的数据是真实的(来自训练集)还是生成器生成的(假数据)。
二者不断相互竞争:
- G的目标:欺骗判别器,让判别器认为其生成的数据是真实的。
- D的目标:尽可能正确地区分真实数据和伪造数据。
三、GAN的训练过程
GAN的训练过程可简单描述为:
- 生成器随机“画”一张假图片(比如猫的照片)。
- 判别器检查图片是真是假,如果是假的,它会给生成器反馈:“你这画得不像!”
- 生成器根据反馈改进自己的作品,努力让下一次生成的图片更逼真。
- 判别器同步提升能力,更精准地辨别真假。
- 两者不断互相对抗,最终生成器能创造出几乎以假乱真的图像。
训练的关键目标是:
- 判别器的损失函数:尽可能正确区分真实数据和伪造数据。
- 生成器的损失函数:让判别器误判,使假数据看起来像真数据。
当训练达到平衡时,即损失均达到0.5左右,判别器无法再轻易区分真假数据,说明生成器已经学会了生成高质量的内容。
四、GAN超详细代码讲解
GAN代码选自github开源项目PyTorch-GAN/implementations/gan/gan.py at master · eriklindernoren/PyTorch-GAN ,生成类似于MNIST手写数字的数据。
4.1 导入必要的库
#argparse:解析命令行参数,方便调整训练超参数
import argparse
#用于创建文件夹、存储图片
import os
import numpy as np
import math
#torchvision包含 MNIST 数据集的下载、预处理及图像保存
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
#torch.nn:用于定义神经网络结构
import torch.nn as nn
import torch.nn.functional as F
import torch
argparse:解析命令行参数,方便调整训练超参数。
os:用于创建文件夹、存储图片。
numpy和math:用于数学计算。
torchvision:包含 MNIST 数据集的下载、预处理及图像保存。
torch:PyTorch 的核心库。
torch.nn:用于定义神经网络结构。
4.2 解析命令行参数
#创建一个命令行参数解析器,用于从终端运行脚本时传入不