DCGAN Tutorials解读

DCGAN

URL

DCGAN Tutorial

DCGAN Github

概述

今天看了dcgan,链接为DCGAN Tutorial,感受到了这个网络框架的强大。gan网络是由两个部分组成,生成器和鉴别器。生成器用来生成假的数据,鉴别器用来识别假的的数据。生成器一直致力于把假的数据“变成”真的数据而鉴别器则是致力于把假的数据识别出来。所以当鉴别器只有 50 % 50\% 50%的概率判断生成器产生的数据为假时,该训练便成功。

GAN表达公式

我们可以定义 x x x为输入图像, D ( x ) D(x) D(x)为鉴别器网络。则当 x x x的值来自训练图像时, D ( x ) D(x) D(x)的值为1;当 x x x的值来自生成图像时, D ( x ) D(x) D(x)的值为0。 D ( x ) D(x) D(x)可以当作典型的二分类。对于生成器,我们可以定义 z z z为一个遵循标准正态分布的虚拟空间向量, G ( z ) G(z) G(z)表示生成器。 G ( z ) G(z) G(z)的目的是可以估计训练数据( P d a t a P_{data} Pdata)的分布,因此它能够根据数据( P g P_g Pg)分布估计假数据。

D ( G ( z ) ) D(G(z)) D(G(z))是一个估计生成器( G G G)是真实图像的概率模型。以下公式是GAN网络的loss function:
m i n m a x V G       D ( D , G ) = E x − P d a t a ( x ) [ l o g D ( x ) ] + E z − P z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] \mathop{minmaxV}\limits_{G~~~~~D}(D,G) = E_{x-P_{data}(x)}[logD(x)]+E_{z-P_z(z)}[log(1-D(G(z)))] G     DminmaxV(D,G)=ExPdata(x)[logD(x)]+EzPz(z)[log(1D(G(z)))]

理论上讲,当 P g = P d a t a P_g=P_{data} Pg=Pdata时,判别器几乎不能判断输入是否为真。但是,GANs网络的收敛理论一直在研究,在现实模型中不能总是训练到目标点。

DCGAN

DCGAN是GAN网络的简单延申,但是它能分别对鉴别器和生成器进行卷积和反卷积。鉴别器由卷积层,池化层和LeakyReLU激活函数组成。当输入图像是 3 x 64 x 64 3x64x64 3x64x64的RGB图像时,输出是判断输入的图像是否为真的概率分布。生成器是由反卷积层,池化层和ReLU激活函数组成。它的输入为一个遵循标准正太分布的隐藏向量 z z z,输出为一个 3 x 64 x 64 3x64x64 3x64x64的RGB图像。反卷积可以使隐藏向量 z z z转换成一个和RGB图像有相同形状的卷。下面以CelebA为数据集,对CelebA做DCGAN网络训练。

包含库

from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

这部分设置了一个随机种子,如果你想每次运行都得到一个新的结果,可以使用manualSeed = random.randint(1, 10000)这段代码。

该行代码输出为:

Out:
Random Seed:  999

输入

输入部分需要定义一些变量用于后续的训练。

d a t a r o o t dataroot dataroot数据集的路径,最好写绝对路径; w o r k e r s workers worke

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值