GAN网络(理论):https://blog.csdn.net/qq_39862223/article/details/114262928
GAN网络(实验):https://blog.csdn.net/qq_39862223/article/details/114283108
1 GAN网络实验
github上有大量的GAN网络及其变体实验,大多是基于TensorFlow和Keras的,大部分都是对版本有很高的限制,所以我找到了GAN网络的pytorch版本实验:PyTorch-GAN
基于Pytorch版本的GAN网络部署起来更加简洁、快捷。以下是使用方法:
1.1 下载源文件
git clone https://github.com/eriklindernoren/PyTorch-GAN
在源文件/implementations中,可以看到GAN网络及其绝大部分变体的实现代码,包括CGAN、ACGAN、DCGAN等
如果你需要对某一个网络进行试验、复现时,按照之后的步骤进行实现即可。
1.2 安装必要的插件
在跑通代码之前,一些必要的插件需要被安装
- torch>=0.4.0
- torchvision
- matplotlib
- numpy
- scipy
- pillow
- urllib3
- scikit-image
可以直接使用命令进行安装
pip3 install -r requirements.txt
结果如下所示:
安装成功!
1.3 调试代码
这里我仅对最基础的GAN网络进行调试运行,其他变体方法类似。
将这一部分的default设置成适合你电脑的值,比如batch_size = 2等,否则电脑负荷过大,跑不动。
之后对GAN.py直接运行即可
python3 gan.py
运行过程可以看到判别器和生成器的loss变化
1.4 运行结果
运行结果会存储在你指定的文件夹中
上图所示,即为每一次生成器生成的手写数字
上图分别是第一次生成和第100次生成,可以看到生成器有很好的优化和改变。
2 GAN应用于手写数字识别
这一部分对代码进行讲解,我使用mnist手写数字来做数据集,通过生成对抗网络我们希望生成一些“以假乱真”的手写字体。为了加快训练过程,不使用卷积网络来做判别器,使用简单的多层网络来进行判别。
2.1 定义判别器和生成器
- Discriminator Network
# 定义判别器 Discriminator使用多层网络来作为判别器
# 将图片28x28展开成784,然后通过多层感知器,中间经过斜率设置为0.2的LeakyReLU激活函数,
# 最后接sigmoid激活函数得到一个0到1之间的概率进行二分类。
class discriminator(nn.Module):
def __init__(self):
super(discriminator, self).__init__()
self.dis = nn.Sequential(
nn