tensorflow之学习GAN生成对抗网络

该博客介绍了GAN(生成对抗网络)的基本原理,包括生成器和判别器的运作方式,以及它们如何通过博弈达到生成逼真图像的平衡。在训练过程中,详细展示了数据预处理、模型构建、损失函数、优化器的设定以及训练过程。同时,通过代码展示了如何在MNIST数据集上训练GAN,随着训练次数增加,生成的图像逐渐接近真实手写数字。
摘要由CSDN通过智能技术生成

前言

一、GAN原理

GAN主要包括两部分,生成器generator和判别器discriminator。

  • 生成器:用于学习真实图像分布进而让自身生成的图像更加真实,得以骗过判别器。
  • 判别器:对接受的图片进行真假判别。

在这里插入图片描述

我们可以看到,右侧蓝色框出的区域,其实就是CNN神经网络,通过卷积全连接对图片进行判别,是图片or不是图片。
我们再看左侧,下方绿色区域是个生成器,生成器制造假的样本,或者说制造接近于真实的样本。随机向量使用多个反卷积层生成一张图像,然后将图像输入给判别器,同时这个判别器也接收真实样本的输入。

判别器对于真实样本的输入,输出为1,制造出的假样本输出为0。
然后对这个网络进行训练,真实为1,假的为0,训练目标为让生成器生出来的图片越来越接近于1,接近于真实分布,使其骗过生成器。

最后的结果为生成的图像越来越接近于真实的图像,而判别器越来越精确,最后对于真假样本,稳定于概率为0.5。

在训练过程中,生成器努力使生成图像更加真实,而判别器努力识别出图像真假,这个过程相当于一个二人博弈,随着时间推移,生成器和判别器不断进行对抗。
最终两个网络达到一个动态均衡:生成器生成的图像接近于真实图像分布,而判别器识别不出真假图像,对于给定图像的预测为真的概率基本接近0.5。

二、应用领域

  • 图像生成
  • 图像增强
  • 风格化
  • 艺术图像创造

开始代码学习

1.导入相关库及整理数据集

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import glob
import os
tf.__version__
'2.5.0'
(train_images, train_labels),(_, _) = tf.keras.datasets.mnist.load_data() # _表示占位符
train_images.shape
60000, 28, 28)
train_images.dtype
dtype('uint8')
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') # 使其变为四维
# 归一化处理,使其在(-1, 1)之间
train_images = (train_images - 127.5) / 127.5
# 一些参数
BATCH_SIZE = 256
BUFFER_SIZE = 60000
# 创建数据集
datasets = tf.data.Dataset.from_tensor_slices(train_images)
# 全部范围乱序取出BATCH_SIZE数据
datasets = datasets.shuffle(BUFFER_SIZE).batch
  • 4
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值