Tensorflow入门八-生成式对抗网络GAN(百度云源码数据资源)

上一篇:Tensorflow入门七-迁移学习实现VGG16微调猫狗分类(迁移学习源1000分类权重文件百度云可直接再学习,线程)https://blog.csdn.net/qq_36187544/article/details/89885247

下一篇:Tensorflow入门九-keras应用计算泰坦尼克生存率(百度云源码数据资源)https://blog.csdn.net/qq_36187544/article/details/90040433


记录这篇文章,一来可以学习下GAN概论,二来代码完整可用,需要的话可就此代码改进新模型

目录

资源说明

GAN与CGAN

源码及运行效果:


资源说明

链接:https://pan.baidu.com/s/16fAdMkbXq-C9ZNU5tHnBIA 
提取码:hndh 

GAN_TRAIN为训练一代后的文件,附带结果图和模型文件;GAN_RAW为初始文件,可拿着就可以使用,较小

checkpoint保存点文件,GAN_RAW.rar里为空,GAN_TRAIN.rar里有一次迭代的模型文件
data数据集文件
logs日志文件,GAN_RAW.rar里为空,GAN_TRAIN.rar里有一次迭代的日志文件
results结果文件,GAN_RAW.rar里为空,GAN_TRAIN.rar里有一次迭代的输出图像文件
CGAN.py条件生成式对抗网络
GAN.py生成式对抗网络
main.py运行主文件
utils.py和ops.py配置以及工具文件

GAN与CGAN

生成式模型的意义:1.处理高维数据和复杂概率分布能力的检测。2.面临缺乏或缺失数据时,使用生成模型补足,3.可输出多模态结果

原理:

核心思想来自于博弈论的纳什均衡,

简言之,即生成网络生成图片,判别网络判别图片是否为真,生成网络的目的尽可能学习真实数据分布,让判别网络无法区分,而判别网络则是二分类网络,尽可能区分图片是否为真。相当于二者进行博弈,然后不断提升水平

更专业的原理图(有需求再仔细研究吧):

fashion-mnist数据集共10类别(T恤,帽子等)类比于MNIST手写数据集

利用GAN即生成式对抗网络,可生成如下图片:

而CGAN即条件生成式对抗网络,可实现条件式生成,比如分类别生成(所以生成器需要生成与条件匹配的,判别器需要判别图像且符合条件的):


源码及运行效果:

GAN.py:

#-*- coding: utf-8 -*-
"""
Most codes from 
https://github.com/carpedm20/DCGAN-tensorflow
"""
from __future__ import division
import time

from GAN_PACKAGE.ops import *
from GAN_PACKAGE.utils import *

class GAN(object):
    model_name = "GAN"      #模型名称,checkpoint文件夹名称
    
    """ 对实例的属性进行初始化 """
    def __init__(self, sess, epoch, batch_size, z_dim, dataset_name, checkpoint_dir, result_dir, log_dir):
        self.sess = sess
        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = result_dir
        self.log_dir = log_dir
        self.epoch = epoch
        self.batch_size = batch_size
        
        # 参数值
        self.input_height = 28
        self.input_width = 28
        self.output_height = 28
        self.output_width = 28
        self.z_dim = z_dim         # 噪声矢量的维度
        self.c_dim = 1             # 由于fashion是灰度图,因此维度为1
        self.learning_rate = 0.0002
        self.beta1 = 0.5
        self.sample_num = 64  # 设置保存生成图片的数量

        # 载入数据
        self.data_X, self.data_y = load_mnist(self.dataset_name)

        # 每一个epoch中batch数量
        self.num_batches = len(self.data_X) // self.batch_size
    
    """ 搭建判别器 """
    def discriminator(self, x, is_training=True, reuse=False):
        with tf.variable_scope("discriminator", reuse=reuse):
            net = lrelu(conv2d(x, 64, 4, 4, 2, 2, name='d_conv1'))
            net = lrelu(bn(conv2d(net, 128, 4, 4, 2, 2, name='d_conv2'), is_training=is_training, scope='d_bn2'))
            net = tf.reshape(net, [self.batch_size, -1])
            net = lrelu(bn(linear(net, 1024, scope='d_fc3'), is_training=is_training, scope='d_bn3'))
            out_logit = linear(net, 1, scope='d_fc4')
            out = tf.nn.sigmoid(out_logit)

            return out, out_logit, net
            
    """ 搭建生成器 """
    def generator(self, z, is_training=True, reuse=False):
        with tf.variable_scope("generator", reuse=reuse):
            net = tf.nn.relu(bn(linear(z, 1024, scope='g_fc1'), is_training=is_training, scope='g_bn1'))
            net = tf.nn.relu(bn(linear(net, 128 * 7 * 7, scope='g_fc2'), is_training=is_training, scope='g_bn2'))
            net = tf.reshape(net, [self.batch_size, 7, 7, 128])
            net = tf.nn.relu(
                bn(deconv2d(net, [self.batch_size, 14, 14, 64], 4, 4, 2, 2, name='g_dc3'), is_training=is_training,
                   scope='g_bn3'))
            out = tf.nn.sigmoid(deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2, name='g_dc4'))

            return out

    """ 构建模型 """
    def build_model(self):
        image_dims = [self.input_height, self.input_width, self.c_dim]
        bs = self.batch_size

        """ 输入 """
        # 图像
        self.inputs = tf.placeholder(tf.float32, [bs] + image_dims, name='real_images')
        # 噪声矢量
        self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')

        """ 损失函数 """
        # 判别器对于真实图像的输出
        D_real, D_real_logits, _ = self.discriminator(self.inputs, is_training=True, reuse=False)
        # 判别器对于生成图像的输出
        G = self.generator(self.z, is_training=True, reuse=False)
        D_fake, D_fake_logits, _ = self.discriminator(G, is_training=True, reuse=True)
        # 判别器的损失函数
        d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones_like(D_real)))
        d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.zeros_like(D_fake)))

        self.d_loss = d_loss_real + d_loss_fake
        
        # 生成器的损失函数
        self.g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.ones_like(D_fake)))
        
        # 可训练变量
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'd_' in var.name]
        g_vars = [var for var in t_vars if 'g_' in var.name]

        # 优化器
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \
                      .minimize(self.d_loss, var_list=d_vars)
            self.g_optim = tf.train.AdamOptimizer(self.learning_rate*5, beta1=self.beta1) \
                      .minimize(self.g_loss, var_list=g_vars)
        # 生成图像
        self.fake_images = self.generator(self.z, is_training=False, reuse=True)

        d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real)
        d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake)
        d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
        g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
        self.g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum])
        self.d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum])
   
    """ 执行训练 """
    def train(self):
        # 变量的初始化
        tf.global_variables_initializer().run()
        # 图(graph)的输入
        self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size , self.z_dim))
        self.saver = tf.train.Saver()
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)
        # 载入checkpoint
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch = (int)(checkpoint_counter / self.num_batches)
            start_batch_id = checkpoint_counter - start_epoch * self.num_batches
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")
        # epoch迭代
        start_time = time.time()
        for epoch in range(start_epoch, self.epoch):
            # 获取批量数据
            for idx in range(start_batch_id, self.num_batches):
                batch_images = self.data_X[idx*self.batch_size:(idx+1)*self.batch_size]
                batch_z = np.random.uniform(-1, 1, [self.batch_size, self.z_dim]).astype(np.float32)
                # 更新判别器
                _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss],
                                               feed_dict={self.inputs: batch_images, self.z: batch_z})
                self.writer.add_summary(summary_str, counter)
                # 更新生成器
                _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss], feed_dict={self.z: batch_z})
                self.writer.add_summary(summary_str, counter)
                # 显示训练状态
                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, idx, self.num_batches, time.time() - start_time, d_loss, g_loss))
                # 每50步保存训练结果
                if np.mod(counter, 50) == 0:
                    samples = self.sess.run(self.fake_images, feed_dict={self.z: self.sample_z})
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    save_images(samples[:manifold_h * manifold_w, :, :, :], [manifold_h, manifold_w],
                                './' + check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_train_{:02d}_{:04d}.png'.format(
                                    epoch, idx))
            start_batch_id = 0
            # 保存模型
            self.save(self.checkpoint_dir, counter)
            # 当前结果的可视化
            self.visualize_results(epoch)
        # 保存最终模型
        self.save(self.checkpoint_dir, counter)

    """ 定义功能函数 """
    def visualize_results(self, epoch):
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        z_sample = np.random.uniform(-1, 1, size=(self.batch_size, self.z_dim))
        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                    check_folder(self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_test_all_classes.png')

    @property
    def model_dir(self):
        return "{}_{}_{}_{}".format(
            self.model_name, self.dataset_name,
            self.batch_size, self.z_dim)

    def save(self, checkpoint_dir, step):
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess,os.path.join(checkpoint_dir, self.model_name+'.model'), global_step=step)

    def load(self, checkpoint_dir):
        import re
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0))
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0

迭代一次之后:

可以在result中看到结果图片

可以看到,前几次图片迭代时图片还基本为空,逐渐就更加自然了。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值