详解GAN代码之搭建并详解CGAN代码

 CGAN是Conditional Generative Adversarial Nets的缩写,也称为条件生成对抗网络。条件生成对抗网络指的是在生成对抗网络中加入条件(condition),条件的作用是监督生成对抗网络。本篇博客通过简单代码搭建,向大家解析了条件生成对抗网络CGAN。

   在开始解析CGAN代码之前,笔者想说的是,要理解CGAN,还请大家先明了CGAN的原理,笔者在下面提供一些笔者认为比较好的了解CGAN原理的链接:

(1) 直接进行论文阅读:https://arxiv.org/abs/1411.1784

(2) 可以翻阅站内的一篇博客,笔者认为写得很不错:Conditional Generative Adversarial Nets论文笔记

(3) 笔者也简单解析一下CGAN的原理,原理图如下(截图来自CGAN论文)


   如上图所示,和原始的生成对抗网络相比,条件生成对抗网络CGAN在生成器的输入和判别器的输入中都加入了条件y。这个y可以是任何类型的数据(可以是类别标签,或者其他类型的数据等)。目的是有条件地监督生成器生成的数据,使得生成器生成结果的方式不是完全自由无监督的。

   CGAN训练的目标函数如下图所示:


   从上面的目标函数中可以看到,条件y不仅被送入了判别器的输入中,也被融入了生成器的输入中。下面,笔者就来解析CGAN的代码,首先还是列举一下笔者主要使用的工具和库。

(1) Python 3.5.2

(2) numpy

(3) Tensorflow 1.2

(4) argparse 用来解析命令行参数

(5) random 用来打乱输入顺序

(6) os 用来读取图片路径和文件名

(7) glob 用来读取图片路径和文件名

(8) cv2 用来读取图片

   笔者搭建的CGAN代码分成4大部分,分别是:

(1) train.py 训练的主控程序

(2) image_reader.py 数据读取接口

(3) net.py 定义网络结构

(4) evaluate.py 测试的主控程序

   其中,训练时使用到的文件是(1),(2),(3)项,测试时使用到的文件时(2),(3),(4)。

   下面,笔者放出代码与注释:

首先是train.py文件中的代码:

[python]  view plain  copy
  1. from __future__ import print_function  
  2.   
  3. import argparse  
  4. from random import shuffle  
  5. import random  
  6. import os  
  7. import sys  
  8. import math  
  9. import tensorflow as tf  
  10. import glob  
  11. import cv2  
  12.   
  13. from image_reader import *  
  14. from net import *  
  15.   
  16. parser = argparse.ArgumentParser(description='')  
  17.   
  18. parser.add_argument("--snapshot_dir", default='./snapshots', help="path of snapshots"#保存模型的路径  
  19. parser.add_argument("--out_dir", default='./train_out', help="path of train outputs"#训练时保存可视化输出的路径  
  20. parser.add_argument("--image_size", type=int, default=256, help="load image size"#网络输入的尺度  
  21. parser.add_argument("--random_seed", type=int, default=1234, help="random seed"#随机数种子  
  22. parser.add_argument('--base_lr', type=float, default=0.0002, help='initial learning rate for adam'#学习率  
  23. parser.add_argument('--epoch', dest='epoch', type=int, default=200, help='# of epoch')  #训练的epoch数量  
  24. parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term of adam'#adam优化器的beta1参数  
  25. parser.add_argument("--summary_pred_every", type=int, default=200, help="times to summary."#训练中每过多少step保存训练日志(记录一下loss值)  
  26. parser.add_argument("--write_pred_every", type=int, default=100, help="times to write."#训练中每过多少step保存可视化结果  
  27. parser.add_argument("--save_pred_every", type=int, default=5000, help="times to save."#训练中每过多少step保存模型(可训练参数)  
  28. parser.add_argument("--lamda_l1_weight", type=float, default=0.0, help="L1 lamda"#训练中L1_Loss前的乘数  
  29. parser.add_argument("--lamda_gan_weight", type=float, default=1.0, help="GAN lamda"#训练中GAN_Loss前的乘数  
  30. parser.add_argument("--train_picture_format", default='.png', help="format of training datas."#网络训练输入的图片的格式(图片在CGAN中被当做条件)  
  31. parser.add_argument("--train_label_format", default='.jpg', help="format of training labels."#网络训练输入的标签的格式(标签在CGAN中被当做真样本)  
  32. parser.add_argument("--train_picture_path", default='./dataset/train_picture/', help="path of training datas."#网络训练输入的图片路径  
  33. parser.add_argument("--train_label_path", default='./dataset/train_label/', help="path of training labels."#网络训练输入的标签路径  
  34.   
  35. args = parser.parse_args() #用来解析命令行参数  
  36. EPS = 1e-12 #EPS用于保证log函数里面的参数大于零  
  37.   
  38. def save(saver, sess, logdir, step): #保存模型的save函数  
  39.    model_name = 'model' #保存的模型名前缀  
  40.    checkpoint_path = os.path.join(logdir, model_name) #模型的保存路径与名称  
  41.    if not os.path.exists(logdir): #如果路径不存在即创建  
  42.       os.makedirs(logdir)  
  43.    saver.save(sess, checkpoint_path, global_step=step) #保存模型  
  44.    print('The checkpoint has been created.')  
  45.   
  46. def cv_inv_proc(img): #cv_inv_proc函数将读取图片时归一化的图片还原成原图  
  47.     img_rgb = (img + 1.) * 127.5  
  48.     return img_rgb.astype(np.float32) #返回bgr格式的图像,方便cv2写图像  
  49.   
  50. def get_write_picture(picture, gen_label, label, height, width): #get_write_picture函数得到训练过程中的可视化结果  
  51.     picture_image = cv_inv_proc(picture) #还原输入的图像  
  52.     gen_label_image = cv_inv_proc(gen_label[0]) #还原生成的样本  
  53.     label_image = cv_inv_proc(label) #还原真实的样本(标签)  
  54.     inv_picture_image = cv2.resize(picture_image, (width, height)) #还原图像的尺寸  
  55.     inv_gen_label_image = cv2.resize(gen_label_image, (width, height)) #还原生成的样本的尺寸  
  56.     inv_label_image = cv2.resize(label_image, (width, height)) #还原真实的样本的尺寸  
  57.     output = np.concatenate((inv_picture_image, inv_gen_label_image, inv_label_image), axis=1#把他们拼起来  
  58.     return output  
  59.       
  60. def l1_loss(src, dst): #定义l1_loss  
  61.     return tf.reduce_mean(tf.abs(src - dst))  
  62.   
  63. def main(): #训练程序的主函数  
  64.     if not os.path.exists(args.snapshot_dir): #如果保存模型参数的文件夹不存在则创建  
  65.         os.makedirs(args.snapshot_dir)  
  66.     if not os.path.exists(args.out_dir): #如果保存训练中可视化输出的文件夹不存在则创建  
  67.         os.makedirs(args.out_dir)  
  68.     train_picture_list = glob.glob(os.path.join(args.train_picture_path, "*")) #得到训练输入图像路径名称列表  
  69.     tf.set_random_seed(args.random_seed) #初始一下随机数  
  70.     train_picture = tf.placeholder(tf.float32,shape=[1, args.image_size, args.image_size, 3],name='train_picture'#输入的训练图像  
  71.     train_label = tf.placeholder(tf.float32,shape=[1, args.image_size, args.image_size, 3],name='train_label'#输入的与训练图像匹配的标签  
  72.   
  73.     gen_label = generator(image=train_picture, gf_dim=64, reuse=False, name='generator'#得到生成器的输出  
  74.     dis_real = discriminator(image=train_picture, targets=train_label, df_dim=64, reuse=False, name="discriminator"#判别器返回的对真实标签的判别结果  
  75.     dis_fake = discriminator(image=train_picture, targets=gen_label, df_dim=64, reuse=True, name="discriminator"#判别器返回的对生成(虚假的)标签判别结果  
  76.   
  77.     gen_loss_GAN = tf.reduce_mean(-tf.log(dis_fake + EPS)) #计算生成器损失中的GAN_loss部分  
  78.     gen_loss_L1 = tf.reduce_mean(l1_loss(gen_label, train_label)) #计算生成器损失中的L1_loss部分  
  79.     gen_loss = gen_loss_GAN * args.lamda_gan_weight + gen_loss_L1 * args.lamda_l1_weight #计算生成器的loss  
  80.   
  81.     dis_loss = tf.reduce_mean(-(tf.log(dis_real + EPS) + tf.log(1 - dis_fake + EPS))) #计算判别器的loss  
  82.   
  83.     gen_loss_sum = tf.summary.scalar("gen_loss", gen_loss) #记录生成器loss的日志  
  84.     dis_loss_sum = tf.summary.scalar("dis_loss", dis_loss) #记录判别器loss的日志  
  85.   
  86.     summary_writer = tf.summary.FileWriter(args.snapshot_dir, graph=tf.get_default_graph()) #日志记录器  
  87.   
  88.     g_vars = [v for v in tf.trainable_variables() if 'generator' in v.name] #所有生成器的可训练参数  
  89.     d_vars = [v for v in tf.trainable_variables() if 'discriminator' in v.name] #所有判别器的可训练参数  
  90.   
  91.     d_optim = tf.train.AdamOptimizer(args.base_lr, beta1=args.beta1) #判别器训练器  
  92.     g_optim = tf.train.AdamOptimizer(args.base_lr, beta1=args.beta1) #生成器训练器  
  93.   
  94.     d_grads_and_vars = d_optim.compute_gradients(dis_loss, var_list=d_vars) #计算判别器参数梯度  
  95.     d_train = d_optim.apply_gradients(d_grads_and_vars) #更新判别器参数  
  96.     g_grads_and_vars = g_optim.compute_gradients(gen_loss, var_list=g_vars) #计算生成器参数梯度  
  97.     g_train = g_optim.apply_gradients(g_grads_and_vars) #更新生成器参数  
  98.   
  99.     train_op = tf.group(d_train, g_train) #train_op表示了参数更新操作  
  100.     config = tf.ConfigProto()  
  101.     config.gpu_options.allow_growth = True #设定显存不超量使用  
  102.     sess = tf.Session(config=config) #新建会话层  
  103.     init = tf.global_variables_initializer() #参数初始化器  
  104.   
  105.     sess.run(init) #初始化所有可训练参数  
  106.   
  107.     saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=50#模型保存器  
  108.   
  109.     counter = 0 #counter记录训练步数  
  110.   
  111.     for epoch in range(args.epoch): #训练epoch数  
  112.         shuffle(train_picture_list) #每训练一个epoch,就打乱一下输入的顺序  
  113.         for step in range(len(train_picture_list)): #每个训练epoch中的训练step数  
  114.             counter += 1  
  115.             picture_name, _ = os.path.splitext(os.path.basename(train_picture_list[step])) #获取不包含路径和格式的输入图片名称  
  116.         #读取一张训练图片,一张训练标签,以及相应的高和宽  
  117.             picture_resize, label_resize, picture_height, picture_width = ImageReader(file_name=picture_name, picture_path=args.train_picture_path, label_path=args.train_label_path, picture_format = args.train_picture_format, label_format = args.train_label_format, size = args.image_size)  
  118.             batch_picture = np.expand_dims(np.array(picture_resize).astype(np.float32), axis = 0#填充维度  
  119.             batch_label = np.expand_dims(np.array(label_resize).astype(np.float32), axis = 0#填充维度  
  120.             feed_dict = { train_picture : batch_picture, train_label : batch_label } #构造feed_dict  
  121.             gen_loss_value, dis_loss_value, _ = sess.run([gen_loss, dis_loss, train_op], feed_dict=feed_dict) #得到每个step中的生成器和判别器loss  
  122.             if counter % args.save_pred_every == 0#每过save_pred_every次保存模型  
  123.                 save(saver, sess, args.snapshot_dir, counter)  
  124.             if counter % args.summary_pred_every == 0#每过summary_pred_every次保存训练日志  
  125.                 gen_loss_sum_value, discriminator_sum_value = sess.run([gen_loss_sum, dis_loss_sum], feed_dict=feed_dict)  
  126.                 summary_writer.add_summary(gen_loss_sum_value, counter)  
  127.                 summary_writer.add_summary(discriminator_sum_value, counter)  
  128.             if counter % args.write_pred_every == 0#每过write_pred_every次写一下训练的可视化结果  
  129.                 gen_label_value = sess.run(gen_label, feed_dict=feed_dict) #run出生成器的输出  
  130.                 write_image = get_write_picture(picture_resize, gen_label_value, label_resize, picture_height, picture_width) #得到训练的可视化结果  
  131.                 write_image_name = args.out_dir + "/out"+ str(counter) + ".png" #待保存的训练可视化结果路径与名称  
  132.                 cv2.imwrite(write_image_name, write_image) #保存训练的可视化结果  
  133.             print('epoch {:d} step {:d} \t gen_loss = {:.3f}, dis_loss = {:.3f}'.format(epoch, step, gen_loss_value, dis_loss_value))  
  134.       
  135. if __name__ == '__main__':  
  136.     main()  

然后是image_reader.py文件:

[python]  view plain  copy
  1. import os  
  2. import numpy as np  
  3. import tensorflow as tf  
  4. import cv2  
  5.   
  6. #读取图片的函数,接收六个参数  
  7. #输入参数分别是图片名,图片路径,标签路径,图片格式,标签格式,需要调整的尺寸大小  
  8. def ImageReader(file_name, picture_path, label_path, picture_format = ".png", label_format = ".jpg", size = 256):  
  9.     picture_name = picture_path + file_name + picture_format #得到图片名称和路径  
  10.     label_name = label_path + file_name + label_format #得到标签名称和路径  
  11.     picture = cv2.imread(picture_name, 1#读取图片  
  12.     label = cv2.imread(label_name, 1#读取标签  
  13.     height = picture.shape[0#得到图片的高  
  14.     width = picture.shape[1#得到图片的宽  
  15.     picture_resize_t = cv2.resize(picture, (size, size)) #调整图片的尺寸,改变成网络输入的大小  
  16.     picture_resize = picture_resize_t / 127.5 - 1. #归一化图片  
  17.     label_resize_t = cv2.resize(label, (size, size)) #调整标签的尺寸,改变成网络输入的大小  
  18.     label_resize = label_resize_t / 127.5 - 1. #归一化标签  
  19.     return picture_resize, label_resize, height, width #返回网络输入的图片,标签,还有原图片和标签的长宽  

接着是net.py文件:

[python]  view plain  copy
  1. import numpy as np  
  2. import tensorflow as tf  
  3. import math  
  4.   
  5. #构造可训练参数  
  6. def make_var(name, shape, trainable = True):  
  7.     return tf.get_variable(name, shape, trainable = trainable)  
  8.   
  9. #定义卷积层  
  10. def conv2d(input_, output_dim, kernel_size, stride, padding = "SAME", name = "conv2d", biased = False):  
  11.     input_dim = input_.get_shape()[-1]  
  12.     with tf.variable_scope(name):  
  13.         kernel = make_var(name = 'weights', shape=[kernel_size, kernel_size, input_dim, output_dim])  
  14.         output = tf.nn.conv2d(input_, kernel, [1, stride, stride, 1], padding = padding)  
  15.         if biased:  
  16.             biases = make_var(name = 'biases', shape = [output_dim])  
  17.             output = tf.nn.bias_add(output, biases)  
  18.         return output  
  19.   
  20. #定义空洞卷积层  
  21. def atrous_conv2d(input_, output_dim, kernel_size, dilation, padding = "SAME", name = "atrous_conv2d", biased = False):  
  22.     input_dim = input_.get_shape()[-1]  
  23.     with tf.variable_scope(name):  
  24.         kernel = make_var(name = 'weights', shape = [kernel_size, kernel_size, input_dim, output_dim])  
  25.         output = tf.nn.atrous_conv2d(input_, kernel, dilation, padding = padding)  
  26.         if biased:  
  27.             biases = make_var(name = 'biases', shape = [output_dim])  
  28.             output = tf.nn.bias_add(output, biases)  
  29.         return output  
  30.   
  31. #定义反卷积层  
  32. def deconv2d(input_, output_dim, kernel_size, stride, padding = "SAME", name = "deconv2d"):  
  33.     input_dim = input_.get_shape()[-1]  
  34.     input_height = int(input_.get_shape()[1])  
  35.     input_width = int(input_.get_shape()[2])  
  36.     with tf.variable_scope(name):  
  37.         kernel = make_var(name = 'weights', shape = [kernel_size, kernel_size, output_dim, input_dim])  
  38.         output = tf.nn.conv2d_transpose(input_, kernel, [1, input_height * 2, input_width * 2, output_dim], [1221], padding = "SAME")  
  39.         return output  
  40.   
  41. #定义batchnorm(批次归一化)层  
  42. def batch_norm(input_, name="batch_norm"):  
  43.     with tf.variable_scope(name):  
  44.         input_dim = input_.get_shape()[-1]  
  45.         scale = tf.get_variable("scale", [input_dim], initializer=tf.random_normal_initializer(1.00.02, dtype=tf.float32))  
  46.         offset = tf.get_variable("offset", [input_dim], initializer=tf.constant_initializer(0.0))  
  47.         mean, variance = tf.nn.moments(input_, axes=[1,2], keep_dims=True)  
  48.         epsilon = 1e-5  
  49.         inv = tf.rsqrt(variance + epsilon)  
  50.         normalized = (input_-mean)*inv  
  51.         output = scale*normalized + offset  
  52.         return output  
  53.   
  54. #定义lrelu激活层  
  55. def lrelu(x, leak=0.2, name = "lrelu"):  
  56.     return tf.maximum(x, leak*x)  
  57.   
  58. #定义生成器,采用UNet架构,主要由8个卷积层和8个反卷积层组成  
  59. def generator(image, gf_dim=64, reuse=False, name="generator"):  
  60.     input_dim = int(image.get_shape()[-1]) #获取输入通道  
  61.     dropout_rate = 0.5 #定义dropout的比例  
  62.     with tf.variable_scope(name):  
  63.         if reuse:  
  64.             tf.get_variable_scope().reuse_variables()  
  65.         else:  
  66.             assert tf.get_variable_scope().reuse is False  
  67.     #第一个卷积层,输出尺度[1, 128, 128, 64]  
  68.         e1 = batch_norm(conv2d(input_=image, output_dim=gf_dim, kernel_size=4, stride=2, name='g_e1_conv'), name='g_bn_e1')  
  69.     #第二个卷积层,输出尺度[1, 64, 64, 128]  
  70.         e2 = batch_norm(conv2d(input_=lrelu(e1), output_dim=gf_dim*2, kernel_size=4, stride=2, name='g_e2_conv'), name='g_bn_e2')  
  71.     #第三个卷积层,输出尺度[1, 32, 32, 256]  
  72.         e3 = batch_norm(conv2d(input_=lrelu(e2), output_dim=gf_dim*4, kernel_size=4, stride=2, name='g_e3_conv'), name='g_bn_e3')  
  73.     #第四个卷积层,输出尺度[1, 16, 16, 512]  
  74.         e4 = batch_norm(conv2d(input_=lrelu(e3), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e4_conv'), name='g_bn_e4')  
  75.     #第五个卷积层,输出尺度[1, 8, 8, 512]  
  76.         e5 = batch_norm(conv2d(input_=lrelu(e4), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e5_conv'), name='g_bn_e5')  
  77.     #第六个卷积层,输出尺度[1, 4, 4, 512]  
  78.         e6 = batch_norm(conv2d(input_=lrelu(e5), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e6_conv'), name='g_bn_e6')  
  79.     #第七个卷积层,输出尺度[1, 2, 2, 512]  
  80.         e7 = batch_norm(conv2d(input_=lrelu(e6), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e7_conv'), name='g_bn_e7')  
  81.     #第八个卷积层,输出尺度[1, 1, 1, 512]  
  82.         e8 = batch_norm(conv2d(input_=lrelu(e7), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e8_conv'), name='g_bn_e8')  
  83.   
  84.     #第一个反卷积层,输出尺度[1, 2, 2, 512]  
  85.         d1 = deconv2d(input_=tf.nn.relu(e8), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d1')  
  86.         d1 = tf.nn.dropout(d1, dropout_rate) #随机扔掉一般的输出  
  87.         d1 = tf.concat([batch_norm(d1, name='g_bn_d1'), e7], 3)  
  88.     #第二个反卷积层,输出尺度[1, 4, 4, 512]  
  89.         d2 = deconv2d(input_=tf.nn.relu(d1), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d2')  
  90.         d2 = tf.nn.dropout(d2, dropout_rate) #随机扔掉一般的输出  
  91.         d2 = tf.concat([batch_norm(d2, name='g_bn_d2'), e6], 3)  
  92.     #第三个反卷积层,输出尺度[1, 8, 8, 512]  
  93.         d3 = deconv2d(input_=tf.nn.relu(d2), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d3')  
  94.         d3 = tf.nn.dropout(d3, dropout_rate) #随机扔掉一般的输出  
  95.         d3 = tf.concat([batch_norm(d3, name='g_bn_d3'), e5], 3)  
  96.     #第四个反卷积层,输出尺度[1, 16, 16, 512]  
  97.         d4 = deconv2d(input_=tf.nn.relu(d3), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d4')  
  98.         d4 = tf.concat([batch_norm(d4, name='g_bn_d4'), e4], 3)  
  99.     #第五个反卷积层,输出尺度[1, 32, 32, 256]  
  100.         d5 = deconv2d(input_=tf.nn.relu(d4), output_dim=gf_dim*4, kernel_size=4, stride=2, name='g_d5')  
  101.         d5 = tf.concat([batch_norm(d5, name='g_bn_d5'), e3], 3)  
  102.     #第六个反卷积层,输出尺度[1, 64, 64, 128]  
  103.         d6 = deconv2d(input_=tf.nn.relu(d5), output_dim=gf_dim*2, kernel_size=4, stride=2, name='g_d6')  
  104.         d6 = tf.concat([batch_norm(d6, name='g_bn_d6'), e2], 3)  
  105.     #第七个反卷积层,输出尺度[1, 128, 128, 64]  
  106.         d7 = deconv2d(input_=tf.nn.relu(d6), output_dim=gf_dim, kernel_size=4, stride=2, name='g_d7')  
  107.         d7 = tf.concat([batch_norm(d7, name='g_bn_d7'), e1], 3)  
  108.     #第八个反卷积层,输出尺度[1, 256, 256, 3]  
  109.         d8 = deconv2d(input_=tf.nn.relu(d7), output_dim=input_dim, kernel_size=4, stride=2, name='g_d8')  
  110.         return tf.nn.tanh(d8)  
  111.   
  112. #定义判别器  
  113. def discriminator(image, targets, df_dim=64, reuse=False, name="discriminator"):  
  114.     with tf.variable_scope(name):  
  115.         if reuse:  
  116.             tf.get_variable_scope().reuse_variables()  
  117.         else:  
  118.             assert tf.get_variable_scope().reuse is False  
  119.         dis_input = tf.concat([image, targets], 3)  
  120.     #第1个卷积模块,输出尺度: 1*128*128*64  
  121.         h0 = lrelu(conv2d(input_ = dis_input, output_dim = df_dim, kernel_size = 4, stride = 2, name='d_h0_conv'))  
  122.     #第2个卷积模块,输出尺度: 1*64*64*128  
  123.         h1 = lrelu(batch_norm(conv2d(input_ = h0, output_dim = df_dim*2, kernel_size = 4, stride = 2, name='d_h1_conv'), name='d_bn1'))  
  124.     #第3个卷积模块,输出尺度: 1*32*32*256  
  125.         h2 = lrelu(batch_norm(conv2d(input_ = h1, output_dim = df_dim*4, kernel_size = 4, stride = 2, name='d_h2_conv'), name='d_bn2'))  
  126.     #第4个卷积模块,输出尺度: 1*32*32*512  
  127.         h3 = lrelu(batch_norm(conv2d(input_ = h2, output_dim = df_dim*8, kernel_size = 4, stride = 1, name='d_h3_conv'), name='d_bn3'))  
  128.     #最后一个卷积模块,输出尺度: 1*32*32*1  
  129.         output = conv2d(input_ = h3, output_dim = 1, kernel_size = 4, stride = 1, name='d_h4_conv')  
  130.         dis_out = tf.sigmoid(output) #在输出之前经过sigmoid层,因为需要进行log运算  
  131.         return dis_out  

   上面就是训练所需的全部代码,大家可以看到,在net.py文件中。生成器使用UNet结构,在生成器和判别器中,image参数就是指的条件,并且在生成器的输入中,随机噪声被去掉了(仅仅输入了条件);在判别器的输入中,条件和待判别的图像被拼接(concat)了起来。

   如果需要开启训练,可以调整train.py中的最后四个参数,根据自己的需求调整训练输入的图片和标签文件路径和相应的格式。另外,由于CGAN训练中需要匹配条件与判别图片,因此,训练读取的图片和标签名称应该是匹配的,在image_reader.py中也能看到,程序是按照同一个名称,去检索训练一个批次输入的图像和对应的标签。

下面是evaluate.py文件:

[python]  view plain  copy
  1. import argparse  
  2. import sys  
  3. import math  
  4. import tensorflow as tf  
  5. import numpy as np  
  6. import glob  
  7. import cv2  
  8.   
  9. from image_reader import *  
  10. from net import *  
  11.   
  12. parser = argparse.ArgumentParser(description='')  
  13.   
  14. parser.add_argument("--test_picture_path", default='./dataset/test_picture/', help="path of test datas.")#网络测试输入的图片路径  
  15. parser.add_argument("--test_label_path", default='./dataset/test_label/', help="path of test datas."#网络测试输入的标签路径  
  16. parser.add_argument("--image_size", type=int, default=256, help="load image size"#网络输入的尺度  
  17. parser.add_argument("--test_picture_format", default='.png', help="format of test pictures."#网络测试输入的图片的格式  
  18. parser.add_argument("--test_label_format", default='.jpg', help="format of test labels."#网络测试时读取的标签的格式  
  19. parser.add_argument("--snapshots", default='./snapshots/',help="Path of Snapshots"#读取训练好的模型参数的路径  
  20. parser.add_argument("--out_dir", default='./test_output/',help="Output Folder"#保存网络测试输出图片的路径  
  21.   
  22. args = parser.parse_args() #用来解析命令行参数  
  23.   
  24.   
  25. def cv_inv_proc(img): #cv_inv_proc函数将读取图片时归一化的图片还原成原图  
  26.     img_rgb = (img + 1.) * 127.5  
  27.     return img_rgb.astype(np.float32) #返回bgr格式的图像,方便cv2写图像  
  28.   
  29.   
  30. def get_write_picture(picture, gen_label, label, height, width): #get_write_picture函数得到网络测试的结果  
  31.     picture_image = cv_inv_proc(picture) #还原输入的图像  
  32.     gen_label_image = cv_inv_proc(gen_label[0]) #还原生成的结果  
  33.     label_image = cv_inv_proc(label) #还原读取的标签  
  34.     inv_picture_image = cv2.resize(picture_image, (width, height)) #将输入图像还原到原大小  
  35.     inv_gen_label_image = cv2.resize(gen_label_image, (width, height)) #将生成的结果还原到原大小  
  36.     inv_label_image = cv2.resize(label_image, (width, height)) #将标签还原到原大小  
  37.     output = np.concatenate((inv_picture_image, inv_gen_label_image, inv_label_image), axis=1#拼接得到输出结果  
  38.     return output  
  39.   
  40.   
  41. def main():  
  42.     if not os.path.exists(args.out_dir): #如果保存测试结果的文件夹不存在则创建  
  43.         os.makedirs(args.out_dir)  
  44.   
  45.     test_picture_list = glob.glob(os.path.join(args.test_picture_path, "*")) #得到测试输入图像路径名称列表  
  46.     test_picture = tf.placeholder(tf.float32, shape=[12562563], name='test_picture'#测试输入的图像  
  47.   
  48.     gen_label = generator(image=test_picture, gf_dim=64, reuse=False, name='generator'#得到生成器的生成结果  
  49.   
  50.     restore_var = [v for v in tf.global_variables() if 'generator' in v.name] #需要载入的已训练的模型参数  
  51.   
  52.     config = tf.ConfigProto()  
  53.     config.gpu_options.allow_growth = True #设定显存不超量使用  
  54.     sess = tf.Session(config=config) #建立会话层  
  55.   
  56.     saver = tf.train.Saver(var_list=restore_var, max_to_keep=1#导入模型参数时使用  
  57.     checkpoint = tf.train.latest_checkpoint(args.snapshots) #读取模型参数  
  58.     saver.restore(sess, checkpoint) #导入模型参数  
  59.   
  60.     for step in range(len(test_picture_list)):  
  61.         picture_name, _ = os.path.splitext(os.path.basename(test_picture_list[step])) #得到一张网络测试的输入图像名字  
  62.     #读取一张测试图片,一张标签,以及相应的高和宽  
  63.         picture_resize, label_resize, picture_height, picture_width = ImageReader(file_name=picture_name,  
  64.                                                                                   picture_path=args.test_picture_path,  
  65.                                                                                   label_path=args.test_label_path,  
  66.                                                                                   picture_format=args.test_picture_format,  
  67.                                                                                   label_format=args.test_label_format,  
  68.                                                                                   size=args.image_size)  
  69.         batch_picture = np.expand_dims(np.array(picture_resize).astype(np.float32), axis=0#填充维度  
  70.         feed_dict = {test_picture: batch_picture} #构造feed_dict  
  71.         gen_label_value = sess.run(gen_label, feed_dict=feed_dict) #得到生成结果  
  72.         write_image = get_write_picture(picture_resize, gen_label_value, label_resize, picture_height, picture_width) #得到一张需要存的图像  
  73.         write_image_name = args.out_dir + picture_name + ".png" #为上述的图像构造保存路径与文件名  
  74.         cv2.imwrite(write_image_name, write_image) #保存测试结果  
  75.         print('step {:d}'.format(step))  
  76.   
  77. if __name__ == '__main__':  
  78.     main()  

   如果需要测试训练完毕的模型,相应地更改测试图片和标签输入路径和格式的四个参数即可,并设置读取模型权重的路径即可。

   下面,笔者就以训练的填充轮廓生成建筑图片的例子为大家展示一下CGAN的效果:

   首先是训练时的可视化输出图像,从左到右,第一张是网络的输入图片(条件),第二张是生成器生成的建筑图像,第三张是真实的建筑图像(标签)。

首先是训练200次的输出:


然后是训练5600次的输出:


然后是训练19000次的输出:


然后是训练36500次的输出:


然后是训练46700次的输出:


然后是训练65700次的输出:


然后是训练72400次的输出:


最后是训练96300次的输出:


下面展示一下训练的loss曲线:

生成器的loss曲线:


判别器的loss曲线:


最后展示一下在测试集上面的效果:

左边是输入的图像(条件),中间是生成的图像,右边是标签(真实的样本)。









   上面就是在测试集上面的效果,读者朋友们可以从文章开头笔者放出的链接中下载数据集进行实验。

   在train.py中,如果将lamda_l1_weight参数改成100,就是pix2pix的做法,笔者放了一些测试集的效果(训练有一些过拟合):




   到这里,CGAN的模型搭建及解析就接近尾声了,很感谢Mehdi Mirza和Simon Osindero,为大家带来条件监督的生成对抗网络算法。CGAN还可以做很多有趣的事情,比如说这个有趣的工作:AI可能真的要代替插画师了……,项目链接https://make.girls.moe/#/,通过CGAN有条件地生成二次元萌妹。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值