生死看淡,不服就GAN(六)----用DCGAN生成马的彩色图片

1. 首先我们需要的一组真实样本集来自cifar10,因此先制作一个读取cifar10的脚本。

"""
-------------------------------------------------------生死看淡,不服就GAN-------------------------------------------------------------------------
PROJECT: PreProcess
Author: Ephemeroptera
Date:2019-3-19
QQ:605686962

"""
import numpy as np
import TFRecordTools
import matplotlib.pyplot as plt

# 数据集归一化
def NORMALIZATION(data):
    from sklearn.preprocessing import MinMaxScaler
    minmax = MinMaxScaler()
    # 归一化
    data2 = minmax.fit_transform(data)
    return data2

# 获取cifar10 指定数据集
"""
kind:
    0:飞机 1:汽车 2:鸟 3:猫 4:鹿
    5:狗 6:狐狸 7:马 8:船 9:卡车
"""
def GetCifar10Data(CifarPath, kind):
    import pickle
    # 打开文件
    fo = open(CifarPath, 'rb')
    # 加载文件
    cifar10_dict  = pickle.load(fo, encoding='bytes')
    # cifar10_n标签集
    cifar10_label = cifar10_dict.get(b'labels')
    # cifar10_n数据集
    cifar10_data = cifar10_dict.get(b'data')
    # 提取指定类数据
    L = [label for label in cifar10_label if label == kind]
    C = [cifar10_data[label[0]] for label in enumerate(cifar10_label) if label[1] == kind]
    # 转化为np数组
    C = np.array(C)
    L = np.array(L)
    # 关闭文件
    fo.close()
    print('成功读取cifar10:%s --类别:%d 数据' % (CifarPath, kind))
    return C,L

# 读取全部数据集
def GetCifar10AllData(kind):
    C, L = GetCifar10Data(r'./cifar-10-batches-py/data_batch_1', kind)
    for i in range(2, 6):
        filename = './/cifar-10-batches-py//data_batch_' + str(i)
        # 读取batch_n文件
        data, label = GetCifar10Data(filename, kind)
        # 拼接
        C = np.concatenate((C, data))
        L = np.concatenate((L, label))
    return C,L

if __name__ == '__main__':
    # 取某一类
    C,L = GetCifar10AllData(7)
    # 图像归一化
    C = NORMALIZATION(C)
    # 显示
    imgs = C[-26:-1].reshape(-1,3,32,32).transpose((0,2,3,1))
    fig, axes = plt.subplots(figsize=(7, 7), nrows=5, ncols=5, sharex=True, sharey=True)
    for ax,img in zip(axes.flatten(),imgs):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        ax.imshow(img)
    plt.show()

    # 存入TFR
    TFRecordTools.SaveByTFRecord(C,L,r'./TFR/class7',5)



在该脚本中指定cifar路径和种类(马:7)再运行,读取cifar内容并以TFRecord格式保存,TFRecord是tensorflow便捷的数据集读取格式,上述依赖TFRecordTools脚本下载链接:

https://download.csdn.net/download/ephemeroptera/11088005

可视化如下:

在这里插入图片描述

如图所示,将马的数据集分割成5个TFR文件保存

2.DCGAN的搭建,代码如下(已给出详细注释)

"""
-------------------------------------------------------生死看淡,不服就GAN-------------------------------------------------------------------------
PROJECT: CIFAR10_DCGAN
Author: Ephemeroptera
Date:2019-3-19
QQ:605686962

"""

# 导入包
import numpy as np
import tensorflow as tf
import pickle
import TFRecordTools
import time

############################################### 设置参数 ####################################################################################

real_shape = [-1,32,32,3] # 真实样本尺寸
data_total = 5000 # 真实样本个数
batch_size = 64 # 批大小
noise_size = 128 # 噪声维度
max_iters = 10000 #的最大迭代次数
learning_rate = 0.0002 # 学习率
smooth = 0.1 # 标签平滑参数(label*(1-smooth))
beta1 = 0.4 #ADAM参数
CRITIC_NUM = 1 # 每次迭代判别器训练次数

############################################# 定义生成器和判别器 #############################################################################

# 定义生成器(32x32图片)
def Generator_DC_32x32(z, channel, is_train=True):
    """
    :param z: 噪声信号,tensor类型
    :param channnel: 生成图片的通道数
    :param is_train: 是否为训练状态,该参数主要用于作为batch_normalization方法中的参数使用(训练时候开启)
    """
    # 训练时生成器不允许复用
    with tf.variable_scope("generator", reuse=(not is_train)):

        # layer1: noise_dim --> 4*4*512 --> 4x4x512 -->BN+relu
        layer1 = tf.layers.dense(z, 4 * 4 * 512)
        layer1 = tf.reshape(layer1, [-1, 4, 4, 512])
        layer1 = tf.layers.batch_normalization(layer1, training=is_train,)
        layer1 = tf.nn.relu(layer1)
        # layer1 = tf.nn.dropout(layer1, keep_prob=0.8)# dropout

        # layer2: deconv(ks=3x3,s=2,padding=same):4x4x512 --> 8x8x256 --> BN+relu
        layer2 = tf.layers.conv2d_transpose(layer1, 256, 3, strides=2, padding='same',
                                            kernel_initializer=tf.random_normal_initializer(0, 0.02),
                                            bias_initializer=tf.random_normal_initializer(0, 0.02))
        layer2 = tf.layers.batch_normalization(layer2, training=is_train)
        layer2 = tf.nn.relu(layer2)
        # layer2 = tf.nn.dropout(layer2, keep_prob=0.8)# dropout

        # layer3: deconv(ks=3x3,s=2,padding=same):8x8x256 --> 16x16x128 --> BN+relu
        layer3 = tf.layers.conv2d_transpose(layer2, 128, 3, strides=2, padding='same',
                                            kernel_initializer=tf.random_normal_initializer(0, 0.02),
                                            bias_initializer=tf.random_normal_initializer(0, 0.02))
        layer3 = tf.layers.batch_normalization(layer3, training=is_train)
        layer3 = tf.nn.relu(layer3)
        # layer3 = tf.nn.dropout(layer3, keep_prob=0.8)# dropout

        # layer4: deconv(ks=3x3,s=2,padding=same):16x16x128 --> 32x32x64--> BN+relu
        layer4 = tf.layers.conv2d_transpose(layer3, 64, 3, strides=2, padding='same',
                                            kernel_initializer=tf.random_normal_initializer(0, 0.02),
                                            bias_initializer=tf.random_normal_initializer(0, 0.02))
        layer4 = tf.layers.batch_normalization(layer4, training=is_train)
        layer4 = tf.nn.relu(layer4)
        # layer4 = tf.nn.dropout(layer3, keep_prob=0.8)# dropout

        # logits: deconv(ks=3x3,s=2,padding=same):32x32x64 --> 32x32x3
        logits = tf.layers.conv2d_transpose(layer4, channel, 3, strides=1, padding='same',
                                            kernel_initializer=tf.random_normal_initializer(0, 0.02),
                                            bias_initializer=tf.random_normal_initializer(0, 0.02))
        # outputs
        outputs = tf.tanh(logits)

        return logits,outputs

# 定义判别器(32x32)
def Discriminator_DC_32x32(inputs_img, reuse=False, GAN = False,GP= False,alpha=0.2):
    """
    @param inputs_img: 输入图片,tensor类型
    @param reuse:判别器复用
    @param GP: 使用WGAN-GP时关闭BN
    @param alpha: Leaky ReLU系数
    """

    with tf.variable_scope("discriminator", reuse=reuse):

        # layer1: conv(ks=3x3,s=2,padding=same)+lrelu -->32x32x3 to 16x16x128
        layer1 = tf.layers.conv2d(inputs_img, 128, 3, strides=2, padding='same')
        if GP is False:
            layer1 = tf.layers.batch_normalization(layer1, training=True)
        layer1 = tf.nn.leaky_relu(layer1,alpha=alpha)
        # layer1 = tf.nn.dropout(layer1, keep_prob=0.8)

        # layer2: conv(ks=3x3,s=2,padding=same)+BN+lrelu -->16x16x128 to 8x8x256
        layer2 = tf.layers.conv2d(layer1, 256, 3, strides=2, padding='same')
        if GP is False:
            layer2 = tf.layers.batch_normalization(layer2, training=True)
        layer2 = tf.nn.leaky_relu(layer2, alpha=alpha)
        # layer2 = tf.nn.dropout(layer2, keep_prob=0.8)

        # layer3: conv(ks=3x3,s=2,padding=same)+BN+lrelu -->8x8x256 to 4x4x512
        layer3 = tf.layers.conv2d(layer2, 512, 3, strides=2, padding='same')
        if GP is False:
            layer3 = tf.layers.batch_normalization(layer3, training=True)
        layer3 = tf.nn.leaky_relu(layer3, alpha=alpha)
        layer3 = tf.reshape(layer3, [-1, 4*4* 512])
        # layer3 = tf.nn.dropout(layer2, keep_prob=0.8)

        # logits,output:
        logits = tf.layers.dense(layer3, 1)
        "WGAN:去除sigmoid"
        if GAN:
            outputs = None
        else:
            outputs = tf.sigmoid(logits)

        return logits, outputs

############################################## 定义计算图(网络) #######################################################

#----------------------输入----------------

inputs_real = tf.placeholder(tf.float32, [None, real_shape[1], real_shape[2], real_shape[3]], name='inputs_real') # 真实样本输入
inputs_noise = tf.placeholder(tf.float32, [None, noise_size], name='inputs_noise') # 生成样本输入

#-------------------生成和判别--------------
# 生成样本
_,g_outputs = Generator_DC_32x32(inputs_noise, real_shape[3], is_train=True) # 训练生成器
_,g_test = Generator_DC_32x32(inputs_noise, real_shape[3], is_train=False) # 测试生成器
# 判别样本
d_logits_real, _ = Discriminator_DC_32x32(inputs_real) #识别真样本
d_logits_fake, _ = Discriminator_DC_32x32(g_outputs, reuse=True) ##识别假样本

#------------定义原始GAN的损失函数--------------
# 生成器loss
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                                labels=tf.ones_like(d_logits_fake) * (1 - smooth)))
# 判别器loss_real
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,
                                                                     labels=tf.ones_like(d_logits_real) * (1 - smooth)))
# 判别器loss_fake
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,labels=tf.zeros_like(d_logits_fake)))

# 判别器loss
d_loss = tf.add(d_loss_real, d_loss_fake)

#-------------------训练模型-----------------
# 分别获取生成器和判别器的变量空间
train_vars = tf.trainable_variables()
g_vars = [var for var in train_vars if var.name.startswith("generator")]
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]

# Optimizer
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):# 保证BN白化先完成
    g_train_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(g_loss, var_list=g_vars) # 训练生成器
    d_train_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(d_loss, var_list=d_vars) # 训练判别器

############################################# 调用TFRecord读取数据 #####################################################

# 读取TFR,不打乱文件顺序,指定数据类型,开启多线程
[data,label] = TFRecordTools.ReadFromTFRecord(sameName= r'.\TFR\class7-*',isShuffle= False,datatype= tf.float64,
                                labeltype= tf.int32,isMultithreading= True)
# 批量处理,送入队列数据,指定数据大小,打乱数据项,设置批次大小64
[data_batch,label_batch] = TFRecordTools.DataBatch(data,label,dataSize= 32*32*3,labelSize= 1,
                                                   isShuffle= True,batchSize= 64)

############################################### 迭代 ###################################################################

# 存储训练过程中生成日志
GenLog = []
# 存储loss
losses = []
# 保存生成器变量(仅保存生成器模型,保存最近5个)
saver = tf.train.Saver(var_list=[var for var in tf.trainable_variables()
                                 if var.name.startswith("generator")],max_to_keep=5)
# 定义批预处理
def batch_preprocess(data_batch):
    # 提取批数据
    batch = sess.run(data_batch)
    # 整理成RGB(Nx32x32x3)
    batch_images = np.reshape(batch, [-1, 3, 32, 32]).transpose((0, 2, 3, 1))  # (-1,32,32,3)
    # scale to -1, 1
    batch_images = batch_images * 2 - 1
    return  batch_images

# 生成相关目录保存生成信息
def GEN_DIR():
    import os
    if not os.path.isdir('ckpt'):
        print('文件夹ckpt未创建,现在在当前目录下创建..')
        os.mkdir('ckpt')
    if not os.path.isdir('trainLog'):
        print('文件夹ckpt未创建,现在在当前目录下创建..')
        os.mkdir('trainLog')

# 开启会话
with tf.Session() as sess:
    # 生成相关目录
    GEN_DIR()

    # 初始化变量
    init = (tf.global_variables_initializer(), tf.local_variables_initializer())
    sess.run(init)

    # 开启协调器
    coord = tf.train.Coordinator()
    # 启动线程
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    time_start = time.time() # 开始计时
    for steps in range(max_iters):
        steps += 1

        # 判别器重复训练设置
        if steps < 25 or steps % 500 == 0:
            critic_num = CRITIC_NUM
        else:
            critic_num = CRITIC_NUM

        batch_noise = np.random.normal(size=(batch_size, noise_size))  # 高斯噪声
        batch_images = batch_preprocess(data_batch)# 真实图像样本

        # 重复训练判别器
        for i in range(CRITIC_NUM):
            _ = sess.run(d_train_opt, feed_dict={inputs_real: batch_images,
                                                 inputs_noise: batch_noise})
        # 训练生成器
        _ = sess.run(g_train_opt, feed_dict={inputs_real: batch_images,
                                             inputs_noise: batch_noise})

        #  记录训练信息
        if steps % 5 == 1:
            # (1)记录损失函数
            train_loss_d = d_loss.eval({inputs_real: batch_images,
                                        inputs_noise: batch_noise})
            train_loss_g = g_loss.eval({inputs_real: batch_images,
                                        inputs_noise: batch_noise})
            losses.append([train_loss_d, train_loss_g,steps])

            # (2)记录生成样本
            batch_noise = np.random.normal(size=(batch_size, noise_size))
            gen_samples = sess.run(g_test, feed_dict={inputs_noise: batch_noise})
            genLog = (gen_samples[0:11] + 1) / 2  # 恢复颜色空间(取10张)
            GenLog.append(genLog)

            # (3)打印信息
            print('step {}...'.format(steps),
                  "Discriminator Loss: {:.4f}...".format(train_loss_d),
                  "Generator Loss: {:.4f}...".format(train_loss_g))

        # (4)保存生成模型
        if steps % 300 ==0:
            saver.save(sess, './ckpt/generator.ckpt', global_step=steps)

    # 关闭线程
    coord.request_stop()
    coord.join(threads)
    
#计时结束:
time_end = time.time()
print('迭代结束,耗时:%.2f秒'%(time_end-time_start))

# 保存信息
#  保存loss记录
with open('./trainLog/loss_variation.loss', 'wb') as l:
    losses = np.array(losses)
    pickle.dump(losses,l)
    print('保存loss信息..')

# 保存生成日志
with open('./trainLog/GenLog.log', 'wb') as g:
    pickle.dump(GenLog, g)
    print('保存GenLog信息..')

经过10000次迭代

训练过程中保存了G的生成日志,G和D的损失函数(trainLog目录下),以及G的模型(ckpt目录下)

3.查看生成日志和测试生成器

"""
-------------------------------------------------------生死看淡,不服就GAN-------------------------------------------------------------------------
PROJECT: Show
Author: Ephemeroptera
Date:2019-3-19
QQ:605686962

"""
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import pickle

def ImgShow(IMG,index,nums):
    """
    :param IMG: 生成样本集合
    :param index: 查看某一次训练的下标(列表格式)
    :param nums:  显示某一次生成样本的个数
    """
    # 定义坐标系
    fig, axes = plt.subplots(figsize=( nums+2,len(index)+2), nrows=len(index), ncols=nums, sharey=True, sharex=True)
    if len(index) == 1:
        for ax,img in zip(axes.flatten(),IMG[index[0]]):
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            ax.imshow(img)
    else:
        for ax_row, idx in zip(axes, index):
            img_row = IMG[idx][0:nums]
            for img,ax in zip(img_row,ax_row):
                ax.xaxis.set_visible(False)
                ax.yaxis.set_visible(False)
                ax.imshow(img)
    fig.tight_layout(pad=0)
    plt.show()


if __name__ == '__main__':

##################################################### 生成日志 ###########################################################################
    ### (1) 显示最后一批
    with open('./trainLog/GenLog.log', 'rb') as f:
        #读取生成记录
        GenLog = pickle.load(f)
        GenLog = np.array(GenLog)
       # 显示最后一次
        ImgShow(GenLog,[-1],10)

    ### (2) 显示过程
        # 均匀采样10次
        epoch_idx0 = np.linspace(1, GenLog.shape[0] - 1, 20)
        epoch_idx = [int(i) for i in epoch_idx0]
        ImgShow(GenLog,epoch_idx,10)

    ############################################### 显示损失函数 ##################################################################################

    with open(r'./trainLog/loss_variation.loss','rb') as l:
        losses = pickle.load(l)
        fig, ax = plt.subplots(figsize=(20, 7))
        plt.plot(losses.T[2],losses.T[0], label='Discriminator  Loss')
        plt.plot(losses.T[2],losses.T[1], label='Generator Loss')
        plt.title("Training Losses")
        plt.legend()
        plt.show()


    ###############################################  验证生成器 ###################################################################################
    with tf.Session() as sess:

        meta_graph = tf.train.import_meta_graph('./ckpt/generator.ckpt-9000.meta')# 加载模型
        meta_graph.restore(sess,tf.train.latest_checkpoint('./ckpt'))# 加载最近一次数据
        graph = tf.get_default_graph()
        inputs_noise = graph.get_tensor_by_name("inputs_noise:0")# 获取输入占位符
        d_outputs_fake = graph.get_tensor_by_name("generator/Tanh:0")

        sample_noise= np.random.normal(size=(10, 128))# 生成输入噪声
        gen_samples = sess.run(d_outputs_fake,feed_dict={inputs_noise: sample_noise})# 验证模型
        gen_samples = [(gen_samples[0:11]+1)/2] # 恢复颜色空间
        ImgShow(gen_samples, [0], 10)



最后一次生成样本

训练过程生成日志

损失函数

验证生成器

  • 3
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Ephemeroptera

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值