GAN网络(实验)

这篇博客介绍了基于PyTorch的GAN网络实验,包括如何下载源文件、安装插件、调试代码以及运行结果。作者以手写数字识别为例,详细说明了判别器和生成器的定义、训练过程,并展示了生成图片的进化。
摘要由CSDN通过智能技术生成

GAN网络(理论):https://blog.csdn.net/qq_39862223/article/details/114262928
GAN网络(实验):https://blog.csdn.net/qq_39862223/article/details/114283108

1 GAN网络实验

github上有大量的GAN网络及其变体实验,大多是基于TensorFlow和Keras的,大部分都是对版本有很高的限制,所以我找到了GAN网络的pytorch版本实验:PyTorch-GAN
基于Pytorch版本的GAN网络部署起来更加简洁、快捷。以下是使用方法:

1.1 下载源文件

git clone https://github.com/eriklindernoren/PyTorch-GAN

在这里插入图片描述
在源文件/implementations中,可以看到GAN网络及其绝大部分变体的实现代码,包括CGAN、ACGAN、DCGAN等
在这里插入图片描述
如果你需要对某一个网络进行试验、复现时,按照之后的步骤进行实现即可。

1.2 安装必要的插件

在跑通代码之前,一些必要的插件需要被安装

  • torch>=0.4.0
  • torchvision
  • matplotlib
  • numpy
  • scipy
  • pillow
  • urllib3
  • scikit-image
    可以直接使用命令进行安装

pip3 install -r requirements.txt

结果如下所示:
在这里插入图片描述
安装成功!

1.3 调试代码

这里我仅对最基础的GAN网络进行调试运行,其他变体方法类似。
在这里插入图片描述
将这一部分的default设置成适合你电脑的值,比如batch_size = 2等,否则电脑负荷过大,跑不动。
之后对GAN.py直接运行即可

python3 gan.py

在这里插入图片描述
运行过程可以看到判别器和生成器的loss变化

1.4 运行结果

运行结果会存储在你指定的文件夹中
在这里插入图片描述
上图所示,即为每一次生成器生成的手写数字
在这里插入图片描述
在这里插入图片描述
上图分别是第一次生成和第100次生成,可以看到生成器有很好的优化和改变。

2 GAN应用于手写数字识别

这一部分对代码进行讲解,我使用mnist手写数字来做数据集,通过生成对抗网络我们希望生成一些“以假乱真”的手写字体。为了加快训练过程,不使用卷积网络来做判别器,使用简单的多层网络来进行判别。

2.1 定义判别器和生成器
  • Discriminator Network
# 定义判别器 Discriminator使用多层网络来作为判别器
# 将图片28x28展开成784,然后通过多层感知器,中间经过斜率设置为0.2的LeakyReLU激活函数,
# 最后接sigmoid激活函数得到一个0到1之间的概率进行二分类。
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值