12-GAN-使用GAN生成相似图片-数据增强

一、前言
本篇将使用GAN(生成对抗网络)生成一类花卉图片的训练样本,也就是说我们现在拥有1000张樱花的图片,但是数据集的量太少导致分类模型的判别能力较弱,而且容易导致过拟合等问题,所以这里我们需要增加训练样本,但是实地采集又太麻烦并且爬虫得到的数据冗杂数据又太多,这个时候该怎么办呢?GAN的双网络并行训练的机制解决了这个问题。我们此案例就是依靠GAN的特点来生成更多的图片来扩充整个数据集合。
二、概念概述
GAN由generate(生成式模型)和discriminator(判别式模型)两部分构成。
1.generate:主要是从训练数据中产生相同分布的samples,对于输入x,类别标签y,在生成式模型中估计其联合概率分布(两个及以上随机变量组成的随机向量的概率分布)。
2.discriminator:判断输入是真实数据还是generate生成的数据,即估计样本属于某类的条件概率分布。它采用传统的监督学习方法。
两个模型结合后,经过大量次数的迭代训练会使generate尽可能模拟出以假乱真的样本,而discriminator会有更精确的鉴别真伪数据的能力,最终会达到博弈论中的整体纳什均衡,即discriminator对于generate的数据鉴别结果为正确率和错误率各占50%。

三、具体实现

# _*_ coding:utf-8 _*_


"""
    Created by 南枫木木 2018/9/7
"""
import tensorflow as tf
import numpy as np
import os
import math
import PIL.Image as Image
import matplotlib.pyplot as plt
#图片保存api
from skimage.io import imsave
import pylab

batch_size=3
#train_epochs=10000
max_epoch=500
checkpoint_dir='model/'
#读取并标准化图片信息,返回图片tensor列表
def images_transform(image_dir_path,size=(64,64),channels=3):
    '''
    Picture conversion
    :param image_path:
    :param size:
    :param channels:
    :return:
    '''
    list_images=os.listdir(image_dir_path)
    list_image_tensors=[]
    os.chdir(image_dir_path)
    for i in range(len(list_images)):
        #file_contents=tf.read_file(list_images[i])
        #读取图片信息
        print(list_images[i])
        file_contents=tf.gfile.FastGFile(list_images[i],'rb').read()

        #将二进制图片信息解码
        try:
            image_tensor=tf.image.decode_jpeg(file_contents,channels=channels)
        except BaseException :
            image_tensor=tf.image.decode_image(file_contents,channels=channels)
        #print(image_tensor)
        #print(list_images[i])
        #将图片进行标准化
        image_tensor=tf.image.resize_images(image_tensor,size=size,method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        list_image_tensors.append(image_tensor)
    #返回转化后的tensor列表
    return list_image_tensors

#判断是否存在文件夹
def test_folder_is_exist(text_folder):
    if os.path.exists(text_folder):
        print(os.listdir('./'))
    else:
        os.mkdir(text_folder)
        print(os.listdir('./'))

#保存图片数据
def save_image_message(save_tensor,text_folder,classify_name,data_path='data'):
    #进入DATA的同级目录
    print(os.getcwd())
    os.chdir("../"+text_folder)
    #将tensor reshape为二维数组(1000,150528)
    save_tensor=tf.reshape(save_tensor,shape=(-1,12288))
    #将tensor转换为numpy数组
    save_array=save_tensor.eval()
    #保存array为text文件
    np.savetxt(classify_name+'.txt',save_array)

#判断是否保存过图片数据
def judge_message_exist(text_folder):
    judge_=os.path.exists(text_folder)
    return judge_

#定义参数初始化方式
def get_variable(name,shape,dtype=tf.float32,initializer=tf.random_normal_initializer(0,0.1)):
    return tf.get_variable(name,shape,dtype=dtype,initializer=initializer)

#定义生成模型
def set_generate_model(Noise_Data):
    """
    一维服从正态分布的噪音数据通过反卷积生成原始图片信息
    :param Noise_Data: 一维噪音数据
    :return: 图片信息
    (batch_size, 1, 1, 1024)
    (batch_size, 2, 2, 256)
    (batch_size, 2, 2, 256)
    (batch_size, 4, 4, 128)
    (batch_size, 4, 4, 128)
    (batch_size, 8, 8, 32)
    (batch_size, 8, 8, 32)
    (batch_size, 16, 16, 16)
    (batch_size, 16, 16, 16)
    (batch_size, 32, 32, 8)
    (batch_size, 32, 32, 8)
    (batch_size, 64, 64, 3)
    """
    fc1=1024 #1*1*1024
    Noise_size = Noise_Data.get_shape()[1]
    Noise_Data = tf.reshape(Noise_Data, [batch_size, -1])
    #fc层参数
    w1=get_variable('w1',shape=[Noise_size,fc1])
    b1=get_variable('b1',shape=[fc1])
    w2=get_variable('w2',[3,3,256,1024])
    b2=get_variable('b2',[256])
    w3=get_variable('w3',[3,3,128,256])
    b3=get_variable('b3',[128])
    w4=get_variable('w4',[3,3,32,128])
    b4=get_variable('b4',[32])
    w5=get_variable('w5',[3,3,16,32])
    b5=get_variable('b5',[16])
    w6=get_variable('w6',[3,3,8,16])
    b6=get_variable('b6',[8])
    w7=get_variable('w7', [3, 3, 3, 8])
    b7=get_variable('b7', [3])
    with tf.variable_scope('G_tans_fc1'):
        net=tf.matmul(Noise_Data,w1)+b1
        net=tf.reshape(net,[-1,1,1,fc1])
    with tf.variable_scope('G_tans_Conv2'):
        print(net.get_shape())
        net=tf.nn.conv2d_transpose(net,filter=w2,output_shape=[batch_size,2,2,256],strides=[1,2,2,1],padding='SAME')
        net=tf.nn.bias_add(net,b2)
        print(net.get_shape())
        net=tf.nn.leaky_relu(net)
    with tf.variable_scope('G_tans_Conv3'):
        print(net.get_shape())
        net=tf.nn.conv2d_transpose(net,filter=w3,output_shape=[batch_size,4,4,128],strides=[1,2,2,1],padding='SAME')
        net=tf.nn.bias_add(net,b3)
        print(net.get_shape())
        net=tf.nn.leaky_relu(net)
    with tf.variable_scope('G_tans_Conv4'):
        print(net.get_shape())
        net=tf.nn.conv2d_transpose(net,filter=w4,output_shape=[batch_size,8,8,32],strides=[1,2,2,1],padding='SAME')
        net=tf.nn.bias_add(net,b4)
        print(net.get_shape())
        net=tf.nn.leaky_relu(net)
    with tf.variable_scope('G_tans_Conv5'):
        print(net.get_shape())
        net=tf.nn.conv2d_transpose(net,filter=w5,output_shape=[batch_size,16,16,16],strides=[1,2,2,1],padding='SAME')
        net=tf.nn.bias_add(net,b5)
        print(net.get_shape())
        net=tf.nn.leaky_relu(net)
    with tf.variable_scope('G_tans_Conv6'):
        print(net.get_shape())
        net=tf.nn.conv2d_transpose(net,filter=w6,output_shape=[batch_size,32,32,8],strides=[1,2,2,1],padding='SAME')
        net=tf.nn.bias_add(net,b6)
        print(net.get_shape())
        net=tf.nn.leaky_relu(net)
    with tf.variable_scope('G_tans_Conv7'):
        print(net.get_shape())
        net = tf.nn.conv2d_transpose(net, filter=w7, output_shape=[batch_size, 64, 64, 3],
                                     strides=[1, 2, 2, 1], padding='SAME')
        net = tf.nn.bias_add(net, b7)
        print(net.get_shape())
        net = tf.nn.tanh(net)
    return net,[w1,b1,w2,b2,w3,b3,w4,b4,w5,b5,w6,b6,w7,b7]

#定义判别模型
def set_discriminant_model(X_data,X_generate):
    w_f_1=get_variable('w_f_1',[3,3,3,16])
    b_f_1=get_variable('b_f_1',[16])
    w_f_2=get_variable('w_f_2',[3,3,16,32])
    b_f_2=get_variable('b_f_2',[32])
    w_f_3=get_variable('w_f_3',[3,3,32,64])
    b_f_3=get_variable('b_f_3',[64])
    w_f_4=get_variable('w_f_4', [3, 3, 64, 128])
    b_f_4=get_variable('b_f_4', [128])
    w_fc_5=get_variable('w_fc_5',[128,2])
    b_fc_5=get_variable('b_fc_5',[2])
    X_data=tf.reshape(X_data,[batch_size,64,64,3])
    X_generate=tf.reshape(X_generate,[batch_size,64,64,3])
    x_input=tf.concat([X_data,X_generate],0)
    with tf.variable_scope("conv1"):
        net=tf.nn.conv2d(x_input,filter=w_f_1,strides=[1,1,1,1],padding='SAME')
        net=tf.nn.bias_add(net,b_f_1)
        net=tf.nn.relu(net)
        net=tf.nn.max_pool(net,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
    with tf.variable_scope('conv2'):
        #32*32*16
        net=tf.nn.conv2d(net,filter=w_f_2,strides=[1,1,1,1],padding='SAME')
        net=tf.nn.bias_add(net,b_f_2)
        net=tf.nn.relu(net)
        net=tf.nn.max_pool(net,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
    with tf.variable_scope('conv3'):
        #16*16*32
        net=tf.nn.conv2d(net,filter=w_f_3,strides=[1,1,1,1],padding='SAME')
        net=tf.nn.bias_add(net,b_f_3)
        net=tf.nn.relu(net)
        net=tf.nn.max_pool(net,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
    with tf.variable_scope('conv4'):
        #8*8*64
        net = tf.nn.conv2d(net, filter=w_f_4, strides=[1, 1, 1, 1], padding='SAME')
        net = tf.nn.bias_add(net,b_f_4 )
        net = tf.nn.relu(net)
    with tf.variable_scope('MAX_average_pool5'):
        #8*8*128
        net=tf.nn.avg_pool(net,ksize=[1,8,8,1],strides=[1,8,8,1],padding='SAME')
        #1*1*128
    with tf.variable_scope('fc6'):
        net=tf.reshape(net,[-1,128])
        net=tf.matmul(net,w_fc_5)+b_fc_5
    #==> [batch_size*2,2]
    Y_data=tf.nn.softmax(tf.slice(net,[0,0],[batch_size,-1]))
    Y_generate=tf.nn.softmax(tf.slice(net,[batch_size,0],[-1,-1]))
    print(Y_data.get_shape())
    print(Y_generate.get_shape())
    return Y_data,Y_generate,[w_f_1,b_f_1,w_f_2,b_f_2,w_f_3,b_f_3,w_f_4,b_f_4,w_fc_5,b_fc_5]

#定义读取图片数据函数
def read_message(classify_name='樱花',text_dir='text_data'):
    os.chdir(text_dir+'/')
    print('正在读取',classify_name)
    array=np.loadtxt(classify_name+'.txt',dtype=np.float32)
    os.chdir('../')
    return array

#图片显示函数
def show_photos(file_name,photo_data):
    #将photo_data的三张图片组合成一张图片
    print(photo_data)
    photo_data=np.reshape(photo_data,[-1,64,64,3]).astype(np.float32)
    #_,img_h,img_w,_=photo_data.shape
    #定义复合图片信息
    #grid_h=img_h
    #grid_w=img_w*3+5*(3-1)
    numb = photo_data.shape[0]
    #img_grid = np.zeros((grid_h, grid_w), dtype=np.uint8)
    for index in range(numb):

        print(photo_data[index])
        photo_data[index]=np.array(photo_data[index])

        imsave(file_name+"_{}.png".format(index),photo_data[index])

#主函数
def main():

    #1.数据生成
    sess = tf.InteractiveSession()
    #将图片保存为文件数据
    judge_=judge_message_exist("text_data")
    if judge_:
        pass
    else:
        test_folder_is_exist("text_data")
        list_images=images_transform("./data")
        print(list_images)
        save_image_message(list_images,"text_data","樱花",data_path='data/')
    sess.close()
    #2.构建网络占位符
    #生成模型占位符
    X_G=tf.placeholder(tf.float32,[None,100],name='X_G')
    #判别模型占位符
    X_D=tf.placeholder(tf.float32,[None,64*64*3],name='X_D')
    global_step = tf.Variable(0, name='global_step', trainable=False)
    #3.构建生成模型
    Gt_X, Gt_params = set_generate_model(X_G)
    #4.构建判别模型
    Dc_Y_data,Dc_Y_generate,Dc_params=set_discriminant_model(X_D,Gt_X)
    # 5.损失函数
    # 设1,0 为真实值
    # 设0,1 为判假值
    # Dc_Y_data_arg_max =[ Dc_Y_data[i,tf.argmax(Dc_Y_data, 1)[i]] for i in range(batch_size)]
    Dc_Y_data_arg_max = Dc_Y_data[0][tf.argmax(Dc_Y_data, 1)[0]] + Dc_Y_data[1][tf.argmax(Dc_Y_data, 1)[1]] + \
                        Dc_Y_data[2][tf.argmax(Dc_Y_data, 1)[2]]
    Dc_Y_generate_arg_min = (1 - Dc_Y_generate[0][tf.argmin(Dc_Y_generate, 1)[0]]) + (
    1 - Dc_Y_generate[1][tf.argmin(Dc_Y_generate, 1)[1]]) + (1 - Dc_Y_generate[2][tf.argmin(Dc_Y_generate, 1)[2]])
    d_loss = -(tf.log(tf.reduce_sum(Dc_Y_data_arg_max)) + tf.log(tf.reduce_sum(Dc_Y_generate_arg_min)))
    g_loss = -tf.log(Dc_Y_generate_arg_min)
    # 6.定义优化函数
    optimizer = tf.train.AdamOptimizer(0.0002)
    # var_list作用域选择更新指定参数列表
    d_train = optimizer.minimize(d_loss, var_list=Gt_params)
    g_train = optimizer.minimize(g_loss, var_list=Dc_params)
    saver=tf.train.Saver()
    #7.开始模型训练

    #os.chdir('generate_flowers/')
    with tf.Session() as sess:
        # 变量初始化
        sess.run(tf.global_variables_initializer())
        ckpt=None
        if True:
            # 加载模型继续训练
            ckpt = tf.train.latest_checkpoint(checkpoint_dir)
            if ckpt:
                print("load model ......")
                saver.restore(sess, ckpt)


        #构建随机噪声
        Gt_random_noise=np.random.normal(0,1,[batch_size,100]).astype(np.float32)
        steps=int(1000/batch_size)
        flowers = read_message()
        print(os.getcwd())
        for i in range(sess.run(global_step),max_epoch):
            index = np.random.permutation(1000)
            for j in range(steps):

                print("Epoch:{},step:{}".format(i,j))
                #获取batch_size的图片数据
                Dc_value=flowers[index[j*3:(j+1)*3]]
                #获取batch_size个噪音数据
                Gt_random_noise_two=np.random.normal(0,1,[batch_size,100]).astype(np.float32)

                #执行判别训练
                sess.run(d_train,feed_dict={X_D:Dc_value,X_G:Gt_random_noise_two})
                #执行生成训练
                sess.run(g_train,feed_dict={X_D:Dc_value,X_G:Gt_random_noise_two})

            #执行一个epoch,输出一次信息
            generate_value=sess.run(Gt_X,feed_dict={X_G:Gt_random_noise})
            #generate_params=sess.run(Gt_params,feed_dict={X_G:Gt_random_noise})
            #print("params:{}".format(generate_params))
            print(os.getcwd())
            os.chdir('output_photos')
            show_photos('generate_p{}'.format(i),generate_value)
            os.chdir('../')

            # 每完成一个迭代,将模型保存一次
            sess.run(tf.assign(global_step, i + 1))
            saver.save(sess, os.path.join(checkpoint_dir, 'model'), global_step=global_step)

if __name__=="__main__":
    main()

迭代至600余次,图片展示如下,训练完成后会提供展示最终效果:
这里写图片描述

四、总结与分析
1.本篇代码用了大概5天时间来编写,效率异常低下,但是动手从头到尾的敲一遍,对GAN的知识的理解确实加强了不少,总的来说还是比较值得的。
2.本篇代码的一些计算方式是有问题的,这样不符合一个基本程序员的素养,但是因为赶进度也就这样了,其中损失函数的计算要求对batch_size的个数是固定的(怕问题不仅如此),因为其值是根据交叉熵损失函数进行的推理,可能会有一些问题,希望聪明的读者可以指正和交流。
3.GAN虽然不是相当成熟,但是现在已经可以达到一个重要的研究和学习的点,并且其扩充了CNN和RNN的能力和功能,以期于完成更加智能的使命。所以对于该领域的学习与深究是非常有必要的一件事情。

评论 16
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值