GAN Generative Networks

A Brief Discussion of GAN Generative Networks


Operating system: Ubuntu 18.04
GPU: GTX 1080 Ti
Python version: 2.7 (or 3.7)
Deep learning framework: PyTorch


The overall roadmap is GAN --> DCGAN --> CGAN.
First, let's talk about how to understand GANs (Generative Adversarial Networks).

  • Purpose: mainly used for data augmentation
  • Core idea: an adversarial setup similar to a Turing test
  • Characteristics: not very stable to train; tuning requires a lot of hands-on experience

The arrival of GANs set off another wave of research in deep learning; Yann LeCun called them "the coolest discovery of the past decade."

1. The Principle of GANs

Goal:
Take a vector z drawn from a random distribution (for example, a normal distribution; in multimodal GANs the input can also be an image) and pass it through the generator network G to produce diverse samples that lie on the manifold of the original dataset's images.

How do we train the G network?
How do we judge whether the images G generates really "look like" the original dataset? Here GANs cleverly borrow the idea of a Turing test: we build a second network, D, whose sole job is to tell whether G's outputs look like real images.
(Neat, right? Before this, we typically used an MSE loss between the original and the generated image.)

What the D and G networks do:
D's job is simple: learn to tell real images from fake (synthesized) ones.
G's job is even simpler: learn to make its generated images look like real ones.
The two networks learn from each other in a game. Gradually G's outputs become more and more realistic, while D becomes better and better at telling real from fake. Eventually G's images can fool even a human observer, and D's discrimination ability surpasses a human's as well.
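
Formally, this game is the minimax objective from Goodfellow et al.'s original GAN paper (standard notation, quoted here for reference):

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

D is trained to maximize V (classify real and fake correctly), while G is trained to minimize it (make D misclassify its samples).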

Weaknesses:
The main weakness is clear: training is unstable, partly because the objective is rather abstract. If G and D learn at different paces, two failure modes appear. If D is too strong, G can never fool it and the network struggles to converge. If G is too strong, it fools D too easily, and the generated images develop structural artifacts.
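
A standard mitigation from the original paper for the "D too strong" case: early in training, \log(1 - D(G(z))) saturates when D confidently rejects G's samples, so in practice G instead maximizes

\max_G \; \mathbb{E}_{z \sim p_z(z)}[\log D(G(z))]

This non-saturating loss gives much stronger gradients at the start of training, and it is also what the training-loop sketch at the end of this post uses.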

2. GAN Code Implementation

We use the MNIST dataset, resize the images to (64, 64) during preprocessing, and adopt the DCGAN architecture.

  • Step 1: import the required packages
    Some of the imports below (dataset.MY_XIEHE and the author's loss module) are local modules from the author's own project; they are not used in this snippet, so they are commented out here to keep the code runnable standalone.
# -*- coding: utf-8 -*-
import argparse
import os
import random
import math
import torch
import torchvision
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from torchvision import datasets
# Local modules from the author's own project; not needed for this snippet.
# from dataset import MY_XIEHE
# from loss import DiscriminatorLossLayer, GeneratorLossLayer, loss_MY_VAE
  • Step 2: prepare the dataset
# Root directory for dataset
dataroot = '/home/water/PycharmProjects/project_xiehe/VAE/mnist'

# Number of workers for dataloader
workers = 2

# Batch size during training
batch_size = 16

# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 1

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 10

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

compose = transforms.Compose([
                               transforms.Resize((image_size,image_size)),
                               # transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5,), (0.5,)),
                           ])
# Create the dataloader (download=True fetches MNIST on the first run)
dataloader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(root=dataroot, train=True, download=True, transform=compose),
    batch_size=batch_size, shuffle=True)
# dataloader = torch.utils.data.DataLoader(
#     dataset=MY_XIEHE(root='/home/water/PycharmProjects/project_xiehe/crop1_img', transform=compose),
#     batch_size=batch_size, shuffle=True)
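
As a quick sanity check (assuming MNIST has been downloaded to dataroot), we can fetch one batch and confirm the preprocessing. Note that Normalize((0.5,), (0.5,)) maps pixel values from [0, 1] to [-1, 1], matching the Tanh output of the generator defined below:

# Fetch one batch to verify shape and value range after preprocessing.
real_batch, _ = next(iter(dataloader))
print(real_batch.shape)                                   # torch.Size([16, 1, 64, 64])
print(real_batch.min().item(), real_batch.max().item())   # close to -1.0 and 1.0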
  • Step 3: define the network structures of G and D
# Custom weight initialization (DCGAN convention): conv weights ~ N(0, 0.02),
# batchnorm weights ~ N(1, 0.02), batchnorm biases = 0.
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        n = 6
        self.decoder = nn.Sequential()
        # input is Z, going into a convolution
        self.decoder.add_module('input-conv', nn.ConvTranspose2d(nz, ngf * 2 ** (n - 3), 4, 1, 0, bias=False))
        self.decoder.add_module('input-batchnorm', nn.BatchNorm2d(ngf * 2 ** (n - 3)))
        self.decoder.add_module('input-relu', nn.LeakyReLU(0.2, inplace=True))

        # state size. (ngf * 2**(n-3)) x 4 x 4

        for i in range(n - 3, 0, -1):
            self.decoder.add_module('pyramid{0}-{1}conv'.format(ngf * 2 ** i, ngf * 2 ** (i - 1)),
                                    nn.ConvTranspose2d(ngf * 2 ** i, ngf * 2 ** (i - 1), 4, 2, 1, bias=False))
            self.decoder.add_module('pyramid{0}batchnorm'.format(ngf * 2 ** (i - 1)),
                                    nn.BatchNorm2d(ngf * 2 ** (i - 1)))
            self.decoder.add_module('pyramid{0}relu'.format(ngf * 2 ** (i - 1)), nn.LeakyReLU(0.2, inplace=True))

        self.decoder.add_module('output-conv', nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        self.decoder.add_module('output-tanh', nn.Tanh())

    def forward(self, z):
        output = self.decoder(z)
        return output
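
To see why this stack ends at 64 x 64, recall PyTorch's ConvTranspose2d output-size formula (standard, per the PyTorch docs):

H_{\text{out}} = (H_{\text{in}} - 1) \cdot \text{stride} - 2 \cdot \text{padding} + \text{kernel}

The input layer (kernel 4, stride 1, padding 0) maps the 1 x 1 z vector to 4 x 4; each pyramid layer (kernel 4, stride 2, padding 1) doubles the spatial size, 4 -> 8 -> 16 -> 32, and the output layer doubles it once more to 64.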

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        n = 6
        self.main = nn.Sequential()

        # input is (nc) x 64 x 64
        self.main.add_module('input-conv', nn.Conv2d(nc, ndf, 4, 2, 1, bias=False))
        self.main.add_module('relu', nn.LeakyReLU(0.2, inplace=True))

        # state size. (ndf) x 32 x 32
        for i in range(n - 3):
            self.main.add_module('pyramid{0}-{1}conv'.format(ndf * 2 ** (i), ndf * 2 ** (i + 1)),
                                 nn.Conv2d(ndf * 2 ** (i), ndf * 2 ** (i + 1), 4, 2, 1, bias=False))
            self.main.add_module('pyramid{0}batchnorm'.format(ndf * 2 ** (i + 1)), nn.BatchNorm2d(ndf * 2 ** (i + 1)))
            self.main.add_module('pyramid{0}relu'.format(ndf * 2 ** (i + 1)), nn.LeakyReLU(0.2, inplace=True))

        self.main.add_module('output-conv', nn.Conv2d(ndf * 2 ** (n - 3), 1, 4, 1, 0, bias=False))
        self.main.add_module('output-sigmoid', nn.Sigmoid())

    def forward(self, input):
        output = self.main(input)

        return output.view(-1, 1)
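
The discriminator mirrors this with ordinary convolutions. By the Conv2d size formula

H_{\text{out}} = \left\lfloor \frac{H_{\text{in}} + 2 \cdot \text{padding} - \text{kernel}}{\text{stride}} \right\rfloor + 1

each stride-2 layer halves the input, 64 -> 32 -> 16 -> 8 -> 4, and the final 4 x 4 convolution (stride 1, no padding) collapses the map to 1 x 1, i.e. a single real/fake score per image.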

netG = Generator().cuda()
netG.apply(weights_init)
netD = Discriminator().cuda()
netD.apply(weights_init)
print(netG)
print(netD)
  • The printed network structures are as follows
Generator(
  (decoder): Sequential(
    (input-conv): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (input-batchnorm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (input-relu): LeakyReLU(negative_slope=0.2, inplace)
    (pyramid512-256conv): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid256batchnorm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pyramid256relu): LeakyReLU(negative_slope=0.2, inplace)
    (pyramid256-128conv): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid128batchnorm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pyramid128relu): LeakyReLU(negative_slope=0.2, inplace)
    (pyramid128-64conv): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid64batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pyramid64relu): LeakyReLU(negative_slope=0.2, inplace)
    (output-conv): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (output-tanh): Tanh()
  )
)
Discriminator(
  (main): Sequential(
    (input-conv): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (relu): LeakyReLU(negative_slope=0.2, inplace)
    (pyramid64-128conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid128batchnorm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pyramid128relu): LeakyReLU(negative_slope=0.2, inplace)
    (pyramid128-256conv): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid256batchnorm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pyramid256relu): LeakyReLU(negative_slope=0.2, inplace)
    (pyramid256-512conv): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (pyramid512batchnorm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pyramid512relu): LeakyReLU(negative_slope=0.2, inplace)
    (output-conv): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (output-sigmoid): Sigmoid()
  )
)
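
The original post ends after printing the networks, before any training code. For completeness, here is a minimal sketch of the standard DCGAN training loop (alternating D and G updates with BCELoss and Adam, using the hyperparameters defined above and the non-saturating generator loss from Section 1). This is the conventional recipe, not the author's original code:

criterion = nn.BCELoss()
fixed_noise = torch.randn(64, nz, 1, 1, device='cuda')  # reused to visualize progress
real_label, fake_label = 1., 0.

optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

for epoch in range(num_epochs):
    for i, (data, _) in enumerate(dataloader):
        # (1) Update D: maximize log(D(x)) + log(1 - D(G(z)))
        netD.zero_grad()
        real = data.cuda()
        b_size = real.size(0)
        label = torch.full((b_size, 1), real_label, device='cuda')
        errD_real = criterion(netD(real), label)
        errD_real.backward()

        noise = torch.randn(b_size, nz, 1, 1, device='cuda')
        fake = netG(noise)
        label.fill_(fake_label)
        # detach() so this backward pass does not touch G's gradients
        errD_fake = criterion(netD(fake.detach()), label)
        errD_fake.backward()
        optimizerD.step()

        # (2) Update G: maximize log(D(G(z))), the non-saturating loss
        netG.zero_grad()
        label.fill_(real_label)  # G wants D to call its samples real
        errG = criterion(netD(fake), label)
        errG.backward()
        optimizerG.step()

        if i % 100 == 0:
            print('[%d/%d][%d/%d] Loss_D: %.4f  Loss_G: %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     (errD_real + errD_fake).item(), errG.item()))

    # Save a grid of samples from the fixed noise after each epoch.
    with torch.no_grad():
        vutils.save_image(netG(fixed_noise),
                          'samples_epoch_%03d.png' % epoch, normalize=True)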