DCGAN生成cifar10的图片

DCGAN是一种稳定的深度卷积生成对抗网络,其discriminator能提取有效的图像特征,适合图像分类。通过训练,DCGAN学习到有意义的滤波器,并保持从潜在空间到图像的连续性。模型设计包括全卷积网络,避免全连接层,广泛使用批量归一化,以及特定的激活函数配置。计划将预训练的分类器模型参数应用于判别器。
摘要由CSDN通过智能技术生成

DCGAN(Deep Convolutional GAN):

DCGAN的贡献

  • 提出了一类基于卷积神经网络的GANs,称为DCGAN,它在多数情况下训练是稳定的。
  • 与其他非监督方法相比,DCGAN的discriminator提取到的图像特征更有效,更适合用于图像分类任务。
  • 通过训练,DCGAN能学到有意义的 filters。
  • DCGAN的generator能够保持latentspace到image的“连续性”。

DCGAN model
实际上,DCGAN是一类GAN的简称,满足以下设计要求(这些要求更像是一些tricks)的GAN网络都可以称为DCGAN模型。

  • 采用全卷积神经网络。不使用空间池化,取而代之使用带步长的卷积层(strided convolution)。这么做能让网络自己学习更合适的空间下采样方法。PS:对于generator来说,要做上采样,采用的是分数步长的卷积(fractionally-stridedconvolution);对于discriminator来说,一般采用整数步长的卷积。
  • 避免在卷积层之后使用全连接层。全连接层虽然增加了模型的稳定性,但也减缓了收敛速度。一般来说,generator的输入(噪声)采用均匀分布;discriminator的最后一个卷积层一般先摊平(flatten),然后接一个单节点的softmax。
  • 除了generator的输出层和discriminator的输入层以外,其他层都是采用batch normalization。Batch normalization能确保每个节点的输入都是均值为0,方差为1。即使是初始化很差,也能保证网络中有足够强的梯度。
  • 对于generator,输出层的激活函数采用Tanh,其它层的激活函数采用ReLU。对于discriminator,激活函数采用leaky ReLU。

文件夹路径及配置:

![在这里插入图片描述](https://img-blog.csdnimg.cn/20210321162042566.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L0RldmlsX0I=,size_16,color_FFFFFF,t_70) data:存放数据集 output:存放输出 weights:存放保存的模型参数

gan.py:

from __future__ import print_function
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils



if __name__ == '__main__':
    cudnn.benchmark = True

    # 将手动种子设置为常数可获得一致的输出
    manualSeed = random.randint(1, 10000)
    print("Random Seed: ", manualSeed)
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)

    # 加载数据集
    dataset = dset.CIFAR10(root="./data", download=True,
                           transform=transforms.Compose([
                               transforms.Resize(64),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
    nc = 3

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=128,
                                             shuffle=True, num_workers=2)

    # 检查是用GPU还是CPU
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # 可以获得的GPU数量
    ngpu = 1

    # 输入噪声尺寸
    nz = 100
    # 生成器过滤数量
    ngf = 64
    # 判别器过滤数量
    ndf = 64


    # 网络和网络上调用的自定义权重初始化
    def weights_init(m):
        classname = m.__class__.__name__
        if classna
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值