条件深度卷积生成对抗网络——生成MNIST图片

本文介绍了使用条件深度卷积生成对抗网络(cDCGAN)生成MNIST手写数字的实践过程。文章详细阐述了算法描述、流程、网络结构,并分享了训练环境和代码实现,最终展示了一些生成结果。
摘要由CSDN通过智能技术生成

欢迎来我的博客 http://www.blackblog.tech,我的简书 https://www.jianshu.com/u/55a1bc4688c6

前几日,学校期末作业要求我们使用机器学习的方法解决一个实际问题,思考了很久,尝试做了很多选题,最终决定做一个cDCGAN,即条件深度卷积生成对抗网络。
为什么做这个选题呢?
生成对抗网络这几年实在是火爆,图片上色,视频去马赛克,包括英伟达最近展出的白马变棕马,白天变黑夜,都是使用生成对抗网络实现的。
2014年”Generative Adversarial Nets”这篇论文中所提到的生成对抗网络是一个无监督的生成对抗网络,且没有使用卷积与反卷积操作。
今天我们以MNIST手写集为数据集,使用tensorflow实现cDCGAN(条件深度卷积生成对抗网络)

算法描述

生成对抗网络(Generative Adversarial Nets)启发自博弈论中的两人零和博弈,GAN模型中的两位博弈方分别有生成网络(Generator)与判别网络(Discriminator)充当。当生成网络G捕捉到样本数据分布,用服从某一分布的噪声z生成一个类似真实训练数据的样本,与真实样本越接近越好;判别网络D一般是一个二分类模型,在本文中D是一个多分类器,用于估计一个样本来自于真实数据的概率,如果样本来自于真实数据,则D输出大概率,否则输出小概率。本文中,判别网络需要在此基础上实现分类功能。

在训练的过程中,需要固定一方,更新另一方的网络状态,如此交替进行。在整个训练的过程中,双方都极力优化自己的网络,从而形成竞争对抗,知道双方达到一个动态的平衡。此时生成网络训练出来的数据与真实数据的分布几乎相同,判别网络也无法再判断出真伪。
本文中生成对抗网络主要分为两部分,生成网络(Generator)与判别网络(Discriminator)。向生成网络内输入噪声,通过多次反卷积的方式得到一个28x28x1的图像作为X_fake,此时将真实的图像X_real与生成器生成的X_fake放入判别网络,判别网络使用多次卷积与Sigmoid函数并通过交叉熵函数计算出判别网络的损失函数D_loss,通过判别网络的损失函数D_loss计算得到生成网络损失函数G_loss。使用G_loss与D_loss对生成网络与判别网络进行参数调整。

算法流程

1.输入噪声z
2.通过生成网络G得到X_fake=G(z)
3.从数据集中获取真实数据X_real
4.通过判别网络D计算D(real logits)=D(X_real)
5.通过判别网络D计算D(fake logits)=D(X_fake)
6.使用交叉熵函数做损失函数根据D(real logits)计算D(loss real)
7.使用交叉熵函数做损失函数根据D(fake logits)计算D(loss fake)
8.计算判别网络损失函数D_loss=D(loss real)+ D_(loss fake)
9.使用交叉熵函数做损失函数计算生成网络损失函数G_loss
10.使用D_loss对判别网络进行参数调整,使用G_loss对生成网络参数进行调整

网络结构

生成网络

生成网络

判别网络

判别网络

数据集

MNIST…..
就不多说啥了

训练环境

系统:Windows 10
框架:tensorflow 1.2
CPU:Intel core i5-4210H
GPU:Nvidia GTX 960M 4G(买不起显卡……..)

上代码!

一些常量的定义,包括学校率,batch_size,保存的路径等等

import os, time, random,itertools
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import cv2
# 保存图片
dirpath = 'GAN/'
model = 'GAN_MINIST'
if not os.path.isdir(dirpath):
    os.mkdir(dirpath)
if not os.path.isdir(dirpath + 'FakeImg'):
    os.mkdir(dirpath + 'FakeImg')
# 初始化
IMAGE_SIZE = 28
onehot = np.eye(10)
noise_ = np.random.normal(0, 1, (10, 1, 1, 100))
fixed_noise_ = noise_
fixed_label_ = np.zeros((10, 1))
#用于最后显示十组图像
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值