生成对抗网络的概念
最基本的GAN模型由一个生成器 G 和判别器 D 组成。生成器用于生成假样本,判别器用于判断样本是真实的还是假的。
- 生成器(Generator):通过机器生成数据(大部分情况下是图像),目的是“骗过”判别器
- 判别器(Discriminator):判断这张图像是真实的还是机器生成的,目的是找出生成器做的“假数据”
首先,固定判别器D,训练生成器G。让生成器不断生成假数据,然后让判别器D去判断,一开始生成器G生成的结果很容易被判别器D识别,然而随着不断的训练,生成器G效果不断提升,直到判别器无法分辩出数据的真假,也就是说这时判别器判断真假数据的概率为0.5.
然后固定生成器G,训练判别器D。当判别器无法分辩生成器生成的数据的时候,这时继续训练生成器是没有意义的。这时,可以训练判别器D,提升判别器D的性能。
数据集的显示
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
tran