Building a GAN with tensorflow.keras (runnable as-is)


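Preface

Keras is one of TensorFlow's high-level APIs; its code is concise and readable. This article uses tensorflow.keras to implement a GAN. The underlying theory is not covered in depth here; the article is meant as a worked example for discussion.

#####Keras reference documentation (Chinese)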


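Main text

I. Outline of building a GAN with tf.keras

1. First convert all image data to TensorFlow's tfrecords format using creat_tfrecords.py. (I originally wrote this script to generate labels for image classification; a GAN does not need the labels to be saved.)
2. Build the dataset from the generated tfrecords file with tf.data.TFRecordDataset. The complete script also includes a second, queue-based way to read the tfrecords data (read_and_decode); both approaches achieve the same result.
3. Build the generator network.
4. Build the discriminator network and combine the two into the GAN (the discriminator must be set to non-trainable before the GAN is compiled).
5. Write a loop that trains the discriminator and the generator in turn.
6. Save the network (gan.model).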

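II. Usage

1. Creating the tfrecords dataset

creat_tfrecords.py
By default the tfrecords file is written to filename_train="./data/train.tfrecords".
Run in a terminal: python creat_tfrecords.py --data [dataset directory]
This generates train.tfrecords; you can also add validation and test sets yourself.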

import os

import tensorflow as tf
from PIL import Image

#class names; the label is the index of the name found in the file name
objects = ['cat','dog']#'cat' -> 0, 'dog' -> 1

filename_train="./data/train.tfrecords"
writer_train= tf.python_io.TFRecordWriter(filename_train)

#the default is None (not the string 'None') so that a missing --data can be detected
tf.app.flags.DEFINE_string(
    'data', None, 'Directory that contains the images.')
FLAGS = tf.app.flags.FLAGS

if FLAGS.data is None:
    raise SystemExit('Usage: python creat_tfrecords.py --data [dataset directory]')

dim = (224,224)
object_path = FLAGS.data
total = os.listdir(object_path)
for index in total:
    img_path=os.path.join(object_path,index)
    #force 3 channels so that the reader can reshape to 224*224*3
    img=Image.open(img_path).convert('RGB')
    img=img.resize(dim)
    img_raw=img.tobytes()
    #take the label from the file name; skip files that match no class
    value = None
    for i in range(len(objects)):
        if objects[i] in index:
            value = i
    if value is None:
        continue
    example = tf.train.Example(features=tf.train.Features(feature={
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[value])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
                }))
    print([index,value])
    writer_train.write(example.SerializeToString())  #serialize to a string
writer_train.close()
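
The record stores only the raw pixel bytes from img.tobytes(), with no shape or dtype, so the reading side has to know both. The NumPy sketch below illustrates that round trip on a synthetic array. The random image and the names pixels, raw and restored are made up for illustration; the sketch only stands in for what decode_raw and reshape do when the data is read back.

```python
import numpy as np

# Synthetic stand-in for a resized 224x224 RGB image (illustrative data only).
rng = np.random.default_rng(0)
pixels = rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)

# Writing side: like img.tobytes(), this keeps the bytes but drops shape and dtype.
raw = pixels.tobytes()

# Reading side: the dtype and the shape must be supplied again by hand.
restored = np.frombuffer(raw, dtype=np.uint8).reshape(224, 224, 3)

print(len(raw))                          # bytes stored per image
print(np.array_equal(pixels, restored))  # whether the round trip is lossless
```

This is why the image size used when writing (dim) and the shape used when reading must agree.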

2. Reading the data

Build the dataset with tf.data.TFRecordDataset.
The code is below (load_image is passed to map to decode the records). Call it from the main function:
train_datas,iter = dataset_tfrecords(tfrecords_path,use_keras_fit=False)

def load_image(serialized_example):
    #map() runs after batch(), so serialized_example is a batch of records
    features={
        'label': tf.io.FixedLenFeature([], tf.int64),
        'img_raw' : tf.io.FixedLenFeature([], tf.string)}
    parsed_example = tf.io.parse_example(serialized_example,features)
    image = tf.io.decode_raw(parsed_example['img_raw'],tf.uint8)
    image = tf.reshape(image,[-1,224,224,3])
    image = tf.cast(image,tf.float32)*(1./255)#scale to [0,1]
    label = tf.cast(parsed_example['label'], tf.int32)
    label = tf.reshape(label,[-1,1])
    return image,label

def dataset_tfrecords(tfrecords_path,use_keras_fit=True):
    #batch_size and epochs are module-level settings (see the complete script)
    #whether the dataset is meant for keras fit; if so it is repeated only once
    if use_keras_fit:
        epochs_data = 1
    else:
        epochs_data = epochs
    #several files can be listed: [tfrecords_name1,tfrecords_name2,...], e.g. collected with os.listdir
    dataset = tf.data.TFRecordDataset([tfrecords_path])
    #note: shuffle must come before batch
    dataset = (dataset
                .repeat(epochs_data)
                .shuffle(1000)
                .batch(batch_size)
                .map(load_image,num_parallel_calls = 2))

    iter = dataset.make_initializable_iterator()#or make_one_shot_iterator()
    train_datas = iter.get_next() #train_datas[0] holds the images, train_datas[1] the labels
    return train_datas,iter
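
The note about placing shuffle before batch can be illustrated without TensorFlow. In this plain-Python sketch (the helper names batches, shuffle_then_batch and batch_then_shuffle are hypothetical, and it ignores the bounded buffer that the real shuffle uses), shuffling first mixes individual records across batches, while shuffling after batching only reorders fixed groups.

```python
import random

def batches(items, size):
    """Split a list into consecutive chunks of at most `size` items."""
    return [items[i:i + size] for i in range(0, len(items), size)]

def shuffle_then_batch(items, size, seed=0):
    """Shuffle individual records first, then group them into batches."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    return batches(shuffled, size)

def batch_then_shuffle(items, size, seed=0):
    """Group records first, then shuffle only the order of the batches."""
    grouped = batches(list(items), size)
    random.Random(seed).shuffle(grouped)
    return grouped

records = list(range(12))
print(shuffle_then_batch(records, 4))  # records from anywhere can share a batch
print(batch_then_shuffle(records, 4))  # each batch is still a consecutive run
```

With batch-then-shuffle, the same records always end up together, so every epoch sees identical batches in a different order.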

3. Building the GAN

a. Building the generator network

    generator = keras.models.Sequential([
            #fullyconnected nets
            keras.layers.Dense(256,activation='selu',input_shape=[coding_size]),
            keras.layers.Dense(64,activation='selu'),
            keras.layers.Dense(256,activation='selu'),
            keras.layers.Dense(1024,activation='selu'),
            keras.layers.Dense(7*7*64,activation='selu'),
            keras.layers.Reshape([7,7,64]),
            #7*7*64
            #transposed convolution (deconvolution)
            keras.layers.Conv2DTranspose(64,kernel_size=3,strides=2,padding='same',activation='selu'),
            #14*14*64
            keras.layers.Conv2DTranspose(64,kernel_size=3,strides=2,padding='same',activation='selu'),
            #28*28*64
            keras.layers.Conv2DTranspose(32,kernel_size=3,strides=2,padding='same',activation='selu'),
            #56*56*32
            keras.layers.Conv2DTranspose(16,kernel_size=3,strides=2,padding='same',activation='selu'),
            #112*112*16
            keras.layers.Conv2DTranspose(3,kernel_size=3,strides=2,padding='same',activation='tanh'),#tanh is used instead of sigmoid (its output range is [-1,1], while load_image scales real images to [0,1])
            #224*224*3
            keras.layers.Reshape([224,224,3])
            ])
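
The size comments in the generator follow from the rule that a Keras Conv2DTranspose layer with padding='same' multiplies the spatial size by its stride. A small sketch that traces the five layers (the helper name conv_transpose_same is hypothetical):

```python
def conv_transpose_same(size, strides):
    """Output length of a Conv2DTranspose layer with padding='same'."""
    return size * strides

size = 7          # spatial size after Reshape([7,7,64])
trace = [size]
for _ in range(5):  # five Conv2DTranspose layers, each with strides=2
    size = conv_transpose_same(size, strides=2)
    trace.append(size)

print(trace)  # [7, 14, 28, 56, 112, 224]
```

Starting from 7 and doubling five times gives exactly 224, which is why the dense layer ends with 7*7*64 units.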

b. Building the discriminator network

    discriminator = keras.models.Sequential([
            keras.layers.Conv2D(128,kernel_size=3,padding='same',strides=2,activation='selu',input_shape=[224,224,3]),
            keras.layers.MaxPool2D(pool_size=2),
            #56*56*128
            keras.layers.Conv2D(64,kernel_size=3,padding='same',strides=2,activation='selu'),
            keras.layers.MaxPool2D(pool_size=2),
            #14*14*64
            keras.layers.Conv2D(32,kernel_size=3,padding='same',strides=2,activation='selu'),
            #7*7*32
            keras.layers.Flatten(),
            #dropout 0.4
            keras.layers.Dropout(0.4),
            keras.layers.Dense(512,activation='selu'),
            keras.layers.Dropout(0.4),
            keras.layers.Dense(64,activation='selu'),
            keras.layers.Dropout(0.4),
            #the last net
            keras.layers.Dense(1,activation='sigmoid')
            ])
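
The discriminator's size comments can be checked the same way. This sketch assumes the usual rules: a convolution with padding='same' gives ceil(input/strides), and MaxPool2D with its default 'valid' padding and strides equal to pool_size gives (input - pool_size)//pool_size + 1. The helper names are hypothetical.

```python
import math

def conv_same(size, strides):
    """Output length of a Conv2D layer with padding='same'."""
    return math.ceil(size / strides)

def max_pool_valid(size, pool_size):
    """Output length of MaxPool2D with strides=pool_size and 'valid' padding."""
    return (size - pool_size) // pool_size + 1

size = 224
size = conv_same(size, 2)        # Conv2D(128, strides=2)
size = max_pool_valid(size, 2)   # MaxPool2D(2)
size = conv_same(size, 2)        # Conv2D(64, strides=2)
size = max_pool_valid(size, 2)   # MaxPool2D(2)
size = conv_same(size, 2)        # Conv2D(32, strides=2)

flat_features = size * size * 32  # input width of the first Dense layer
print(size, flat_features)        # 7 1568
```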

c. Combining the generator and the discriminator into the GAN

gan = keras.models.Sequential([generator,discriminator])

4. Compiling (setting the loss and the optimizer)

    #compile the net
    discriminator.compile(loss="binary_crossentropy",optimizer='rmsprop')# metrics=['accuracy'])
    discriminator.trainable=False
    gan.compile(loss="binary_crossentropy",optimizer='rmsprop')# metrics=['accuracy'])

5. Training the network (the loop)

Get the dataset:

train_datas,iter = dataset_tfrecords(tfrecords_path,use_keras_fit=False)

The loop body (cv2 is used inside it to view the generator's output):

    sess = tf.Session()
    sess.run(iter.initializer)
    #start the thread coordinator (only needed by the queue-based reader in the complete script)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess,coord=coord)
    
    generator,discriminator = gan.layers
    print("-----------------start---------------")
    for step in range(num_steps):
        try:
            #get the time
            start_time = time.time()
            #phase 1 - training the discriminator
            noise = np.random.normal(size=[batch_size,coding_size]).astype(np.float32)
            generated_images = generator.predict(noise)
            train_datas_ = sess.run(train_datas)
            #use np.concatenate here; never call tf.concat or create any other tf op inside the loop,
            #because every call adds nodes to the graph: memory runs out and training gets slower and slower
            x_fake_and_real = np.concatenate([generated_images,train_datas_[0]],axis = 0)
            #labels: 0 for generated images, 1 for real images
            y1 = np.array([[0.]]*batch_size+[[1.]]*len(train_datas_[0]))
            discriminator.trainable = True
            dis_loss = discriminator.train_on_batch(x_fake_and_real,y1)
            #keras train_on_batch is a good fit for training a GAN
            #phase 2 - training the generator
            noise = np.random.normal(size=[batch_size,coding_size]).astype(np.float32)
            y2 = np.array([[1.]]*batch_size)
            discriminator.trainable = False
            ad_loss = gan.train_on_batch(noise,y2)
            duration = time.time()-start_time
            if step % 5 == 0:
                #gan.save_weights('gan.h5')
                print("The step is %d,discriminator loss:%.3f,adversarial loss:%.3f"%(step,dis_loss,ad_loss),end=' ')
                print('%.2f s/step'%(duration))
            if step % 30 == 0 and step != 0:
                noise = np.random.normal(size=[1,coding_size]).astype(np.float32)
                fake_image = generator.predict(noise)
                #restore the image
                #1. after multiplying by 255 the values must be cast to uint8
                #2. alternatively keep float32 in [0,1], which can also be displayed directly
                #tanh output can be negative, so clip to [0,1] before scaling
                arr_img = np.clip(fake_image.reshape([224,224,3]),0,1)*255
                arr_img = arr_img.astype(np.uint8)
                #the tfrecords were written with PIL.Image (RGB), so convert to BGR before showing with cv2
                arr_img = cv2.cvtColor(arr_img,cv2.COLOR_RGB2BGR)
                cv2.imshow('fake image',arr_img)
                cv2.waitKey(1500)#show the fake image 1.5s
                cv2.destroyAllWindows()
        #OutOfRangeError is raised once the dataset is exhausted, so re-initialize the iterator
        except tf.errors.OutOfRangeError: 
            sess.run(iter.initializer)
    #stop the thread coordinator
    coord.request_stop()
    coord.join(threads)
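
The two phases differ only in the labels they pair with the images. The NumPy sketch below illustrates this with a hand-written binary cross-entropy; the helper binary_crossentropy, the small batch size and the discriminator outputs d_out are made-up illustrations, not values from the training run or Keras's own implementation.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; predictions are clipped to avoid log(0)."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    losses = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return float(np.mean(losses))

batch_size = 4  # small illustrative batch

# Phase 1 labels: generated images first (0), then real images (1).
y1 = np.array([[0.]] * batch_size + [[1.]] * batch_size)

# Hypothetical discriminator outputs for the stacked fake and real images.
d_out = np.array([[0.1]] * batch_size + [[0.9]] * batch_size)
dis_loss = binary_crossentropy(y1, d_out)

# Phase 2 labels: the generator is trained as if its images were real (1).
y2 = np.array([[1.]] * batch_size)
ad_loss = binary_crossentropy(y2, d_out[:batch_size])

print(y1.shape, y2.shape)
print(dis_loss, ad_loss)
```

When the discriminator is confident and correct, its own loss is small while the adversarial loss is large; the generator update tries to reduce the latter.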

6. Saving the network

Use whichever of the two variants matches your TensorFlow version:

    #TensorFlow 2.0
    #save the models 
    model_vision = '0001'
    model_name = 'gans'
    model_path = os.path.join(model_name,model_name)
    tf.saved_model.save(gan,model_path)
    #TensorFlow 1.13.1
    #save the models 
    model_vision = '0001'
    gan.save_weights(model_vision)

7. The complete gans.py (runnable)

# -*- coding: utf-8 -*-
'''
    @author:zyl
    author is zouyuelin
    a Master of Tianjin University(TJU)
'''

import tensorflow as tf
from tensorflow import keras
#tf.enable_eager_execution()
import numpy as np
from PIL import Image
import os
import cv2
import time

batch_size = 32
epochs = 120
num_steps = 2000
coding_size = 30
tfrecords_path = 'data/train.tfrecords'

#--------------------------------------datasetTfrecord----------------   
def load_image(serialized_example):
    #map() runs after batch(), so serialized_example is a batch of records
    features={
        'label': tf.io.FixedLenFeature([], tf.int64),
        'img_raw' : tf.io.FixedLenFeature([], tf.string)}
    parsed_example = tf.io.parse_example(serialized_example,features)
    image = tf.io.decode_raw(parsed_example['img_raw'],tf.uint8)
    image = tf.reshape(image,[-1,224,224,3])
    image = tf.cast(image,tf.float32)*(1./255)#scale to [0,1]
    label = tf.cast(parsed_example['label'], tf.int32)
    label = tf.reshape(label,[-1,1])
    return image,label

def dataset_tfrecords(tfrecords_path,use_keras_fit=True):
    #whether the dataset is meant for keras fit; if so it is repeated only once
    if use_keras_fit:
        epochs_data = 1
    else:
        epochs_data = epochs
    #several files can be listed: [tfrecords_name1,tfrecords_name2,...], e.g. collected with os.listdir
    dataset = tf.data.TFRecordDataset([tfrecords_path])
    #note: shuffle must come before batch
    dataset = (dataset
                .repeat(epochs_data)
                .shuffle(1000)
                .batch(batch_size)
                .map(load_image,num_parallel_calls = 2))

    iter = dataset.make_initializable_iterator()#or make_one_shot_iterator()
    train_datas = iter.get_next() #train_datas[0] holds the images, train_datas[1] the labels
    return train_datas,iter

#------------------------------------tf.TFRecordReader-----------------
def read_and_decode(tfrecords_path):
    #build a filename queue from the file name
    filename_queue = tf.train.string_input_producer([tfrecords_path],shuffle=True) 
    reader = tf.TFRecordReader()
    _,  serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,features={
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw' : tf.FixedLenFeature([], tf.string)})

    image = tf.decode_raw(features['img_raw'], tf.uint8)
    image = tf.reshape(image,[224,224,3])#reshape to 224*224*3
    image = tf.cast(image,tf.float32)*(1./255)#scale the image tensor to [0,1], i.e. divide by 255
    label = tf.cast(features['label'], tf.int32)
    img_batch, label_batch = tf.train.shuffle_batch([image,label],
                    batch_size=batch_size,
                    num_threads=4,
                    capacity= 640,
                    min_after_dequeue=5)
    return [img_batch,label_batch]

#Autoencoder
def autoencode():
        encoder = keras.models.Sequential([
            keras.layers.Conv2D(32,kernel_size=3,padding='same',strides=2,activation='selu',input_shape=[224,224,3]),
            #112*112*32
            keras.layers.MaxPool2D(pool_size=2),
            #56*56*32
            keras.layers.Conv2D(64,kernel_size=3,padding='same',strides=2,activation='selu'),
            #28*28*64
            keras.layers.MaxPool2D(pool_size=2),
            #14*14*64
            keras.layers.Conv2D(128,kernel_size=3,padding='same',strides=2,activation='selu'),
            #7*7*128
            #transposed convolution (deconvolution)
            keras.layers.Conv2DTranspose(128,kernel_size=3,strides=2,padding='same',activation='selu'),
            #14*14*128
            keras.layers.Conv2DTranspose(64,kernel_size=3,strides=2,padding='same',activation='selu'),
            #28*28*64
            keras.layers.Conv2DTranspose(32,kernel_size=3,strides=2,padding='same',activation='selu'),
            #56*56*32
            keras.layers.Conv2DTranspose(16,kernel_size=3,strides=2,padding='same',activation='selu'),
            #112*112*16
            keras.layers.Conv2DTranspose(3,kernel_size=3,strides=2,padding='same',activation='tanh'),#tanh is used instead of sigmoid
            #224*224*3
            keras.layers.Reshape([224,224,3])
            ])
        return encoder

def training_keras():
    '''
        Output size of convolution and pooling:
            output_size = (input_size-kernel_size+2*padding)/strides+1
            
        Output size of a keras transposed convolution (output_padding is normally not used):
        1. if padding = 'valid':
            output_size = (input_size - 1)*strides + kernel_size
        2. if padding = 'same':
            output_size = input_size * strides
    '''
    generator = keras.models.Sequential([
            #fullyconnected nets
            keras.layers.Dense(256,activation='selu',input_shape=[coding_size]),
            keras.layers.Dense(64,activation='selu'),
            keras.layers.Dense(256,activation='selu'),
            keras.layers.Dense(1024,activation='selu'),
            keras.layers.Dense(7*7*64,activation='selu'),
            keras.layers.Reshape([7,7,64]),
            #7*7*64
            #transposed convolution (deconvolution)
            keras.layers.Conv2DTranspose(64,kernel_size=3,strides=2,padding='same',activation='selu'),
            #14*14*64
            keras.layers.Conv2DTranspose(64,kernel_size=3,strides=2,padding='same',activation='selu'),
            #28*28*64
            keras.layers.Conv2DTranspose(32,kernel_size=3,strides=2,padding='same',activation='selu'),
            #56*56*32
            keras.layers.Conv2DTranspose(16,kernel_size=3,strides=2,padding='same',activation='selu'),
            #112*112*16
            keras.layers.Conv2DTranspose(3,kernel_size=3,strides=2,padding='same',activation='tanh'),#tanh is used instead of sigmoid (its output range is [-1,1], while load_image scales real images to [0,1])
            #224*224*3
            keras.layers.Reshape([224,224,3])
            ])
            
    discriminator = keras.models.Sequential([
            keras.layers.Conv2D(128,kernel_size=3,padding='same',strides=2,activation='selu',input_shape=[224,224,3]),
            keras.layers.MaxPool2D(pool_size=2),
            #56*56*128
            keras.layers.Conv2D(64,kernel_size=3,padding='same',strides=2,activation='selu'),
            keras.layers.MaxPool2D(pool_size=2),
            #14*14*64
            keras.layers.Conv2D(32,kernel_size=3,padding='same',strides=2,activation='selu'),
            #7*7*32
            keras.layers.Flatten(),
            #dropout 0.4
            keras.layers.Dropout(0.4),
            keras.layers.Dense(512,activation='selu'),
            keras.layers.Dropout(0.4),
            keras.layers.Dense(64,activation='selu'),
            keras.layers.Dropout(0.4),
            #the last net
            keras.layers.Dense(1,activation='sigmoid')
            ])
    #gans network        
    gan = keras.models.Sequential([generator,discriminator])
    
    #compile the net
    discriminator.compile(loss="binary_crossentropy",optimizer='rmsprop')# metrics=['accuracy'])
    discriminator.trainable=False
    gan.compile(loss="binary_crossentropy",optimizer='rmsprop')# metrics=['accuracy'])
    
    #dataset
    #train_datas = read_and_decode(tfrecords_path)
    train_datas,iter = dataset_tfrecords(tfrecords_path,use_keras_fit=False)
    
    sess = tf.Session()
    sess.run(iter.initializer)
    
    #start the thread coordinator (only needed by the queue-based reader read_and_decode)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess,coord=coord)
    
    generator,discriminator = gan.layers
    print("-----------------start---------------")
    for step in range(num_steps):
        try:
            #get the time
            start_time = time.time()
            #phase 1 - training the discriminator
            noise = np.random.normal(size=[batch_size,coding_size]).astype(np.float32)
            generated_images = generator.predict(noise)
            train_datas_ = sess.run(train_datas)
            #use np.concatenate here; never call tf.concat or create any other tf op inside the loop,
            #because every call adds nodes to the graph: memory runs out and training gets slower and slower
            x_fake_and_real = np.concatenate([generated_images,train_datas_[0]],axis = 0)
            #labels: 0 for generated images, 1 for real images
            y1 = np.array([[0.]]*batch_size+[[1.]]*len(train_datas_[0]))
            discriminator.trainable = True
            dis_loss = discriminator.train_on_batch(x_fake_and_real,y1)
            #keras train_on_batch is a good fit for training a GAN
            #phase 2 - training the generator
            noise = np.random.normal(size=[batch_size,coding_size]).astype(np.float32)
            y2 = np.array([[1.]]*batch_size)
            discriminator.trainable = False
            ad_loss = gan.train_on_batch(noise,y2)
            duration = time.time()-start_time
            if step % 5 == 0:
                #gan.save_weights('gan.h5')
                print("The step is %d,discriminator loss:%.3f,adversarial loss:%.3f"%(step,dis_loss,ad_loss),end=' ')
                print('%.2f s/step'%(duration))
            if step % 30 == 0 and step != 0:
                noise = np.random.normal(size=[1,coding_size]).astype(np.float32)
                fake_image = generator.predict(noise)
                #restore the image
                #1. after multiplying by 255 the values must be cast to uint8
                #2. alternatively keep float32 in [0,1], which can also be displayed directly
                #tanh output can be negative, so clip to [0,1] before scaling
                arr_img = np.clip(fake_image.reshape([224,224,3]),0,1)*255
                arr_img = arr_img.astype(np.uint8)
                #the tfrecords were written with PIL.Image (RGB), so convert to BGR before showing with cv2
                arr_img = cv2.cvtColor(arr_img,cv2.COLOR_RGB2BGR)
                cv2.imshow('fake image',arr_img)
                cv2.waitKey(1500)#show the fake image 1.5s
                cv2.destroyAllWindows()
        #OutOfRangeError is raised once the dataset is exhausted, so re-initialize the iterator
        except tf.errors.OutOfRangeError: 
            sess.run(iter.initializer)
            
    #stop the thread coordinator
    coord.request_stop()
    coord.join(threads)
    #save the models (TensorFlow 1.13.1): writes the weights under the prefix '0001'
    model_vision = '0001'
    gan.save_weights(model_vision)
    #this script uses the TF 1.x session API; on TensorFlow 2.0 save a SavedModel instead:
    #model_name = 'gans'
    #model_path = os.path.join(model_name,model_name)
    #tf.saved_model.save(gan,model_path)
    
def main():
    training_keras()
main()

This completes a simple GAN training run.


References

Paper: "Generative Adversarial Networks"
Reference source code:
https://github.com/eriklindernoren/Keras-GAN/blob/master/gan/gan.py
Reference blog post:
https://blog.csdn.net/u010138055/article/details/94441812

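Closing words

I am a master's student who is still a beginner in deep learning and machine learning, so corrections and suggestions are very welcome.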
