一、前言
本篇将使用GAN(生成对抗网络)生成一类花卉图片的训练样本,也就是说我们现在拥有1000张樱花的图片,但是数据集的量太少导致分类模型的判别能力较弱,而且容易导致过拟合等问题,所以这里我们需要增加训练样本,但是实地采集又太麻烦并且爬虫得到的数据冗杂数据又太多,这个时候该怎么办呢?GAN的双网络并行训练的机制解决了这个问题。我们此案例就是依靠GAN的特点来生成更多的图片来扩充整个数据集合。
二、概念概述
GAN由generate(生成式模型)和discriminator(判别式模型)两部分构成。
1.generate:主要是从训练数据中产生相同分布的samples,对于输入x,类别标签y,在生成式模型中估计其联合概率分布(两个及以上随机变量组成的随机向量的概率分布)。
2.discriminator:判断输入是真实数据还是generate生成的数据,即估计样本属于某类的条件概率分布。它采用传统的监督学习方法。
两个模型结合后,经过大量次数的迭代训练会使generate尽可能模拟出以假乱真的样本,而discriminator会有更精确的鉴别真伪数据的能力,最终会达到博弈论中的整体纳什均衡,即discriminator对于generate的数据鉴别结果为正确率和错误率各占50%。
三、具体实现
# _*_ coding:utf-8 _*_
"""
Created by 南枫木木 2018/9/7
"""
import tensorflow as tf
import numpy as np
import os
import math
import PIL.Image as Image
import matplotlib.pyplot as plt
#图片保存api
from skimage.io import imsave
import pylab
batch_size=3
#train_epochs=10000
max_epoch=500
checkpoint_dir='model/'
#读取并标准化图片信息,返回图片tensor列表
def images_transform(image_dir_path,size=(64,64),channels=3):
'''
Picture conversion
:param image_path:
:param size:
:param channels:
:return:
'''
list_images=os.listdir(image_dir_path)
list_image_tensors=[]
os.chdir(image_dir_path)
for i in range(len(list_images)):
#file_contents=tf.read_file(list_images[i])
#读取图片信息
print(list_images[i])
file_contents=tf.gfile.FastGFile(list_images[i],'rb').read()
#将二进制图片信息解码
try:
image_tensor=tf.image.decode_jpeg(file_contents,channels=channels)
except BaseException :
image_tensor=tf.image.decode_image(file_contents,channels=channels)
#print(image_tensor)
#print(list_images[i])
#将图片进行标准化
image_tensor=tf.image.resize_images(image_tensor,size=size,method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
list_image_tensors.append(image_tensor)
#返回转化后的tensor列表
return list_image_tensors
#判断是否存在文件夹
def test_folder_is_exist(text_folder):
if os.path.exists(text_folder):
print(os.listdir('./'))
else:
os.mkdir(text_folder)
print(os.listdir('./'))
#保存图片数据
def save_image_message(save_tensor,text_folder,classify_name,data_path='data'):
#进入DATA的同级目录
print(os.getcwd())
os.chdir("../"+text_folder)
#将tensor reshape为二维数组(1000,150528)
save_tensor=tf.reshape(save_tensor,shape=(-1,12288))
#将tensor转换为numpy数组
save_array=save_tensor.eval()
#保存array为text文件
np.savetxt(classify_name+'.txt',save_array)
#判断是否保存过图片数据
def judge_message_exist(text_folder):
judge_=os.path.exists(text_folder)
return judge_
#定义参数初始化方式
def get_variable(name,shape,dtype=tf.float32,initializer=tf.random_normal_initializer(0,0.1)):
return tf.get_variable(name,shape,dtype=dtype,initializer=initializer)
#定义生成模型
def set_generate_model(Noise_Data):
"""
一维服从正态分布的噪音数据通过反卷积生成原始图片信息
:param Noise_Data: 一维噪音数据
:return: 图片信息
(batch_size, 1, 1, 1024)
(batch_size, 2, 2, 256)
(batch_size, 2, 2, 256)
(batch_size, 4, 4, 128)
(batch_size, 4, 4, 128)
(batch_size, 8, 8, 32)
(batch_size, 8, 8, 32)
(batch_size, 16, 16, 16)
(batch_size, 16, 16, 16)
(batch_size, 32, 32, 8)
(batch_size, 32, 32, 8)
(batch_size, 64, 64, 3)
"""
fc1=1024 #1*1*1024
Noise_size = Noise_Data.get_shape()[1]
Noise_Data = tf.reshape(Noise_Data, [batch_size, -1])
#fc层参数
w1=get_variable('w1',shape=[Noise_size,fc1])
b1=get_variable('b1',shape=[fc1])
w2=get_variable('w2',[3,3,256,1024])
b2=get_variable('b2',[256])
w3=get_variable('w3',[3,3,128,256])
b3=get_variable('b3',[128])
w4=get_variable('w4',[3,3,32,128])
b4=get_variable('b4',[32])
w5=get_variable('w5',[3,3,16,32])
b5=get_variable('b5',[16])
w6=get_variable('w6',[3,3,8,16])
b6=get_variable('b6',[8])
w7=get_variable('w7', [3, 3, 3, 8])
b7=get_variable('b7', [3])
with tf.variable_scope('G_tans_fc1'):
net=tf.matmul(Noise_Data,w1)+b1
net=tf.reshape(net,[-1,1,1,fc1])
with tf.variable_scope('G_tans_Conv2'):
print(net.get_shape())
net=tf.nn.conv2d_transpose(net,filter=w2,output_shape=[batch_size,2,2,256],strides=[1,2,2,1],padding='SAME')
net=tf.nn.bias_add(net,b2)
print(net.get_shape())
net=tf.nn.leaky_relu(net)
with tf.variable_scope('G_tans_Conv3'):
print(net.get_shape())
net=tf.nn.conv2d_transpose(net,filter=w3,output_shape=[batch_size,4,4,128],strides=[1,2,2,1],padding='SAME')
net=tf.nn.bias_add(net,b3)
print(net.get_shape())
net=tf.nn.leaky_relu(net)
with tf.variable_scope('G_tans_Conv4'):
print(net.get_shape())
net=tf.nn.conv2d_transpose(net,filter=w4,output_shape=[batch_size,8,8,32],strides=[1,2,2,1],padding='SAME')
net=tf.nn.bias_add(net,b4)
print(net.get_shape())
net=tf.nn.leaky_relu(net)
with tf.variable_scope('G_tans_Conv5'):
print(net.get_shape())
net=tf.nn.conv2d_transpose(net,filter=w5,output_shape=[batch_size,16,16,16],strides=[1,2,2,1],padding='SAME')
net=tf.nn.bias_add(net,b5)
print(net.get_shape())
net=tf.nn.leaky_relu(net)
with tf.variable_scope('G_tans_Conv6'):
print(net.get_shape())
net=tf.nn.conv2d_transpose(net,filter=w6,output_shape=[batch_size,32,32,8],strides=[1,2,2,1],padding='SAME')
net=tf.nn.bias_add(net,b6)
print(net.get_shape())
net=tf.nn.leaky_relu(net)
with tf.variable_scope('G_tans_Conv7'):
print(net.get_shape())
net = tf.nn.conv2d_transpose(net, filter=w7, output_shape=[batch_size, 64, 64, 3],
strides=[1, 2, 2, 1], padding='SAME')
net = tf.nn.bias_add(net, b7)
print(net.get_shape())
net = tf.nn.tanh(net)
return net,[w1,b1,w2,b2,w3,b3,w4,b4,w5,b5,w6,b6,w7,b7]
#定义判别模型
def set_discriminant_model(X_data,X_generate):
w_f_1=get_variable('w_f_1',[3,3,3,16])
b_f_1=get_variable('b_f_1',[16])
w_f_2=get_variable('w_f_2',[3,3,16,32])
b_f_2=get_variable('b_f_2',[32])
w_f_3=get_variable('w_f_3',[3,3,32,64])
b_f_3=get_variable('b_f_3',[64])
w_f_4=get_variable('w_f_4', [3, 3, 64, 128])
b_f_4=get_variable('b_f_4', [128])
w_fc_5=get_variable('w_fc_5',[128,2])
b_fc_5=get_variable('b_fc_5',[2])
X_data=tf.reshape(X_data,[batch_size,64,64,3])
X_generate=tf.reshape(X_generate,[batch_size,64,64,3])
x_input=tf.concat([X_data,X_generate],0)
with tf.variable_scope("conv1"):
net=tf.nn.conv2d(x_input,filter=w_f_1,strides=[1,1,1,1],padding='SAME')
net=tf.nn.bias_add(net,b_f_1)
net=tf.nn.relu(net)
net=tf.nn.max_pool(net,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
with tf.variable_scope('conv2'):
#32*32*16
net=tf.nn.conv2d(net,filter=w_f_2,strides=[1,1,1,1],padding='SAME')
net=tf.nn.bias_add(net,b_f_2)
net=tf.nn.relu(net)
net=tf.nn.max_pool(net,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
with tf.variable_scope('conv3'):
#16*16*32
net=tf.nn.conv2d(net,filter=w_f_3,strides=[1,1,1,1],padding='SAME')
net=tf.nn.bias_add(net,b_f_3)
net=tf.nn.relu(net)
net=tf.nn.max_pool(net,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
with tf.variable_scope('conv4'):
#8*8*64
net = tf.nn.conv2d(net, filter=w_f_4, strides=[1, 1, 1, 1], padding='SAME')
net = tf.nn.bias_add(net,b_f_4 )
net = tf.nn.relu(net)
with tf.variable_scope('MAX_average_pool5'):
#8*8*128
net=tf.nn.avg_pool(net,ksize=[1,8,8,1],strides=[1,8,8,1],padding='SAME')
#1*1*128
with tf.variable_scope('fc6'):
net=tf.reshape(net,[-1,128])
net=tf.matmul(net,w_fc_5)+b_fc_5
#==> [batch_size*2,2]
Y_data=tf.nn.softmax(tf.slice(net,[0,0],[batch_size,-1]))
Y_generate=tf.nn.softmax(tf.slice(net,[batch_size,0],[-1,-1]))
print(Y_data.get_shape())
print(Y_generate.get_shape())
return Y_data,Y_generate,[w_f_1,b_f_1,w_f_2,b_f_2,w_f_3,b_f_3,w_f_4,b_f_4,w_fc_5,b_fc_5]
#定义读取图片数据函数
def read_message(classify_name='樱花',text_dir='text_data'):
os.chdir(text_dir+'/')
print('正在读取',classify_name)
array=np.loadtxt(classify_name+'.txt',dtype=np.float32)
os.chdir('../')
return array
#图片显示函数
def show_photos(file_name,photo_data):
#将photo_data的三张图片组合成一张图片
print(photo_data)
photo_data=np.reshape(photo_data,[-1,64,64,3]).astype(np.float32)
#_,img_h,img_w,_=photo_data.shape
#定义复合图片信息
#grid_h=img_h
#grid_w=img_w*3+5*(3-1)
numb = photo_data.shape[0]
#img_grid = np.zeros((grid_h, grid_w), dtype=np.uint8)
for index in range(numb):
print(photo_data[index])
photo_data[index]=np.array(photo_data[index])
imsave(file_name+"_{}.png".format(index),photo_data[index])
#主函数
def main():
#1.数据生成
sess = tf.InteractiveSession()
#将图片保存为文件数据
judge_=judge_message_exist("text_data")
if judge_:
pass
else:
test_folder_is_exist("text_data")
list_images=images_transform("./data")
print(list_images)
save_image_message(list_images,"text_data","樱花",data_path='data/')
sess.close()
#2.构建网络占位符
#生成模型占位符
X_G=tf.placeholder(tf.float32,[None,100],name='X_G')
#判别模型占位符
X_D=tf.placeholder(tf.float32,[None,64*64*3],name='X_D')
global_step = tf.Variable(0, name='global_step', trainable=False)
#3.构建生成模型
Gt_X, Gt_params = set_generate_model(X_G)
#4.构建判别模型
Dc_Y_data,Dc_Y_generate,Dc_params=set_discriminant_model(X_D,Gt_X)
# 5.损失函数
# 设1,0 为真实值
# 设0,1 为判假值
# Dc_Y_data_arg_max =[ Dc_Y_data[i,tf.argmax(Dc_Y_data, 1)[i]] for i in range(batch_size)]
Dc_Y_data_arg_max = Dc_Y_data[0][tf.argmax(Dc_Y_data, 1)[0]] + Dc_Y_data[1][tf.argmax(Dc_Y_data, 1)[1]] + \
Dc_Y_data[2][tf.argmax(Dc_Y_data, 1)[2]]
Dc_Y_generate_arg_min = (1 - Dc_Y_generate[0][tf.argmin(Dc_Y_generate, 1)[0]]) + (
1 - Dc_Y_generate[1][tf.argmin(Dc_Y_generate, 1)[1]]) + (1 - Dc_Y_generate[2][tf.argmin(Dc_Y_generate, 1)[2]])
d_loss = -(tf.log(tf.reduce_sum(Dc_Y_data_arg_max)) + tf.log(tf.reduce_sum(Dc_Y_generate_arg_min)))
g_loss = -tf.log(Dc_Y_generate_arg_min)
# 6.定义优化函数
optimizer = tf.train.AdamOptimizer(0.0002)
# var_list作用域选择更新指定参数列表
d_train = optimizer.minimize(d_loss, var_list=Gt_params)
g_train = optimizer.minimize(g_loss, var_list=Dc_params)
saver=tf.train.Saver()
#7.开始模型训练
#os.chdir('generate_flowers/')
with tf.Session() as sess:
# 变量初始化
sess.run(tf.global_variables_initializer())
ckpt=None
if True:
# 加载模型继续训练
ckpt = tf.train.latest_checkpoint(checkpoint_dir)
if ckpt:
print("load model ......")
saver.restore(sess, ckpt)
#构建随机噪声
Gt_random_noise=np.random.normal(0,1,[batch_size,100]).astype(np.float32)
steps=int(1000/batch_size)
flowers = read_message()
print(os.getcwd())
for i in range(sess.run(global_step),max_epoch):
index = np.random.permutation(1000)
for j in range(steps):
print("Epoch:{},step:{}".format(i,j))
#获取batch_size的图片数据
Dc_value=flowers[index[j*3:(j+1)*3]]
#获取batch_size个噪音数据
Gt_random_noise_two=np.random.normal(0,1,[batch_size,100]).astype(np.float32)
#执行判别训练
sess.run(d_train,feed_dict={X_D:Dc_value,X_G:Gt_random_noise_two})
#执行生成训练
sess.run(g_train,feed_dict={X_D:Dc_value,X_G:Gt_random_noise_two})
#执行一个epoch,输出一次信息
generate_value=sess.run(Gt_X,feed_dict={X_G:Gt_random_noise})
#generate_params=sess.run(Gt_params,feed_dict={X_G:Gt_random_noise})
#print("params:{}".format(generate_params))
print(os.getcwd())
os.chdir('output_photos')
show_photos('generate_p{}'.format(i),generate_value)
os.chdir('../')
# 每完成一个迭代,将模型保存一次
sess.run(tf.assign(global_step, i + 1))
saver.save(sess, os.path.join(checkpoint_dir, 'model'), global_step=global_step)
if __name__=="__main__":
main()
迭代至600余次,图片展示如下,训练完成后会提供展示最终效果:
四、总结与分析
1.本篇代码用了大概5天时间来编写,效率异常低下,但是动手从头到尾的敲一遍,对GAN的知识的理解确实加强了不少,总的来说还是比较值得的。
2.本篇代码的一些计算方式是有问题的,这样不符合一个基本程序员的素养,但是因为赶进度也就这样了,其中损失函数的计算要求对batch_size的个数是固定的(怕问题不仅如此),因为其值是根据交叉熵损失函数进行的推理,可能会有一些问题,希望聪明的读者可以指正和交流。
3.GAN虽然不是相当成熟,但是现在已经可以达到一个重要的研究和学习的点,并且其扩充了CNN和RNN的能力和功能,以期于完成更加智能的使命。所以对于该领域的学习与深究是非常有必要的一件事情。