CGAN是Conditional Generative Adversarial Nets的缩写,也称为条件生成对抗网络。条件生成对抗网络指的是在生成对抗网络中加入条件(condition),条件的作用是监督生成对抗网络。本篇博客通过简单代码搭建,向大家解析了条件生成对抗网络CGAN。
在开始解析CGAN代码之前,笔者想说的是,要理解CGAN,还请大家先明了CGAN的原理,笔者在下面提供一些笔者认为比较好的了解CGAN原理的链接:
(1) 直接进行论文阅读:https://arxiv.org/abs/1411.1784
(2) 可以翻阅站内的一篇博客,笔者认为写得很不错:Conditional Generative Adversarial Nets论文笔记
(3) 笔者也简单解析一下CGAN的原理,原理图如下(截图来自CGAN论文)
如上图所示,和原始的生成对抗网络相比,条件生成对抗网络CGAN在生成器的输入和判别器的输入中都加入了条件y。这个y可以是任何类型的数据(可以是类别标签,或者其他类型的数据等)。目的是有条件地监督生成器生成的数据,使得生成器生成结果的方式不是完全自由无监督的。
CGAN训练的目标函数如下图所示:
从上面的目标函数中可以看到,条件y不仅被送入了判别器的输入中,也被融入了生成器的输入中。下面,笔者就来解析CGAN的代码,首先还是列举一下笔者主要使用的工具和库。
(1) Python 3.5.2
(2) numpy
(3) Tensorflow 1.2
(4) argparse 用来解析命令行参数
(5) random 用来打乱输入顺序
(6) os 用来读取图片路径和文件名
(7) glob 用来读取图片路径和文件名
(8) cv2 用来读取图片
笔者搭建的CGAN代码分成4大部分,分别是:
(1) train.py 训练的主控程序
(2) image_reader.py 数据读取接口
(3) net.py 定义网络结构
(4) evaluate.py 测试的主控程序
其中,训练时使用到的文件是(1),(2),(3)项,测试时使用到的文件时(2),(3),(4)。
下面,笔者放出代码与注释:
首先是train.py文件中的代码:
- from __future__ import print_function
- import argparse
- from random import shuffle
- import random
- import os
- import sys
- import math
- import tensorflow as tf
- import glob
- import cv2
- from image_reader import *
- from net import *
- parser = argparse.ArgumentParser(description='')
- parser.add_argument("--snapshot_dir", default='./snapshots', help="path of snapshots") #保存模型的路径
- parser.add_argument("--out_dir", default='./train_out', help="path of train outputs") #训练时保存可视化输出的路径
- parser.add_argument("--image_size", type=int, default=256, help="load image size") #网络输入的尺度
- parser.add_argument("--random_seed", type=int, default=1234, help="random seed") #随机数种子
- parser.add_argument('--base_lr', type=float, default=0.0002, help='initial learning rate for adam') #学习率
- parser.add_argument('--epoch', dest='epoch', type=int, default=200, help='# of epoch') #训练的epoch数量
- parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term of adam') #adam优化器的beta1参数
- parser.add_argument("--summary_pred_every", type=int, default=200, help="times to summary.") #训练中每过多少step保存训练日志(记录一下loss值)
- parser.add_argument("--write_pred_every", type=int, default=100, help="times to write.") #训练中每过多少step保存可视化结果
- parser.add_argument("--save_pred_every", type=int, default=5000, help="times to save.") #训练中每过多少step保存模型(可训练参数)
- parser.add_argument("--lamda_l1_weight", type=float, default=0.0, help="L1 lamda") #训练中L1_Loss前的乘数
- parser.add_argument("--lamda_gan_weight", type=float, default=1.0, help="GAN lamda") #训练中GAN_Loss前的乘数
- parser.add_argument("--train_picture_format", default='.png', help="format of training datas.") #网络训练输入的图片的格式(图片在CGAN中被当做条件)
- parser.add_argument("--train_label_format", default='.jpg', help="format of training labels.") #网络训练输入的标签的格式(标签在CGAN中被当做真样本)
- parser.add_argument("--train_picture_path", default='./dataset/train_picture/', help="path of training datas.") #网络训练输入的图片路径
- parser.add_argument("--train_label_path", default='./dataset/train_label/', help="path of training labels.") #网络训练输入的标签路径
- args = parser.parse_args() #用来解析命令行参数
- EPS = 1e-12 #EPS用于保证log函数里面的参数大于零
- def save(saver, sess, logdir, step): #保存模型的save函数
- model_name = 'model' #保存的模型名前缀
- checkpoint_path = os.path.join(logdir, model_name) #模型的保存路径与名称
- if not os.path.exists(logdir): #如果路径不存在即创建
- os.makedirs(logdir)
- saver.save(sess, checkpoint_path, global_step=step) #保存模型
- print('The checkpoint has been created.')
- def cv_inv_proc(img): #cv_inv_proc函数将读取图片时归一化的图片还原成原图
- img_rgb = (img + 1.) * 127.5
- return img_rgb.astype(np.float32) #返回bgr格式的图像,方便cv2写图像
- def get_write_picture(picture, gen_label, label, height, width): #get_write_picture函数得到训练过程中的可视化结果
- picture_image = cv_inv_proc(picture) #还原输入的图像
- gen_label_image = cv_inv_proc(gen_label[0]) #还原生成的样本
- label_image = cv_inv_proc(label) #还原真实的样本(标签)
- inv_picture_image = cv2.resize(picture_image, (width, height)) #还原图像的尺寸
- inv_gen_label_image = cv2.resize(gen_label_image, (width, height)) #还原生成的样本的尺寸
- inv_label_image = cv2.resize(label_image, (width, height)) #还原真实的样本的尺寸
- output = np.concatenate((inv_picture_image, inv_gen_label_image, inv_label_image), axis=1) #把他们拼起来
- return output
- def l1_loss(src, dst): #定义l1_loss
- return tf.reduce_mean(tf.abs(src - dst))
- def main(): #训练程序的主函数
- if not os.path.exists(args.snapshot_dir): #如果保存模型参数的文件夹不存在则创建
- os.makedirs(args.snapshot_dir)
- if not os.path.exists(args.out_dir): #如果保存训练中可视化输出的文件夹不存在则创建
- os.makedirs(args.out_dir)
- train_picture_list = glob.glob(os.path.join(args.train_picture_path, "*")) #得到训练输入图像路径名称列表
- tf.set_random_seed(args.random_seed) #初始一下随机数
- train_picture = tf.placeholder(tf.float32,shape=[1, args.image_size, args.image_size, 3],name='train_picture') #输入的训练图像
- train_label = tf.placeholder(tf.float32,shape=[1, args.image_size, args.image_size, 3],name='train_label') #输入的与训练图像匹配的标签
- gen_label = generator(image=train_picture, gf_dim=64, reuse=False, name='generator') #得到生成器的输出
- dis_real = discriminator(image=train_picture, targets=train_label, df_dim=64, reuse=False, name="discriminator") #判别器返回的对真实标签的判别结果
- dis_fake = discriminator(image=train_picture, targets=gen_label, df_dim=64, reuse=True, name="discriminator") #判别器返回的对生成(虚假的)标签判别结果
- gen_loss_GAN = tf.reduce_mean(-tf.log(dis_fake + EPS)) #计算生成器损失中的GAN_loss部分
- gen_loss_L1 = tf.reduce_mean(l1_loss(gen_label, train_label)) #计算生成器损失中的L1_loss部分
- gen_loss = gen_loss_GAN * args.lamda_gan_weight + gen_loss_L1 * args.lamda_l1_weight #计算生成器的loss
- dis_loss = tf.reduce_mean(-(tf.log(dis_real + EPS) + tf.log(1 - dis_fake + EPS))) #计算判别器的loss
- gen_loss_sum = tf.summary.scalar("gen_loss", gen_loss) #记录生成器loss的日志
- dis_loss_sum = tf.summary.scalar("dis_loss", dis_loss) #记录判别器loss的日志
- summary_writer = tf.summary.FileWriter(args.snapshot_dir, graph=tf.get_default_graph()) #日志记录器
- g_vars = [v for v in tf.trainable_variables() if 'generator' in v.name] #所有生成器的可训练参数
- d_vars = [v for v in tf.trainable_variables() if 'discriminator' in v.name] #所有判别器的可训练参数
- d_optim = tf.train.AdamOptimizer(args.base_lr, beta1=args.beta1) #判别器训练器
- g_optim = tf.train.AdamOptimizer(args.base_lr, beta1=args.beta1) #生成器训练器
- d_grads_and_vars = d_optim.compute_gradients(dis_loss, var_list=d_vars) #计算判别器参数梯度
- d_train = d_optim.apply_gradients(d_grads_and_vars) #更新判别器参数
- g_grads_and_vars = g_optim.compute_gradients(gen_loss, var_list=g_vars) #计算生成器参数梯度
- g_train = g_optim.apply_gradients(g_grads_and_vars) #更新生成器参数
- train_op = tf.group(d_train, g_train) #train_op表示了参数更新操作
- config = tf.ConfigProto()
- config.gpu_options.allow_growth = True #设定显存不超量使用
- sess = tf.Session(config=config) #新建会话层
- init = tf.global_variables_initializer() #参数初始化器
- sess.run(init) #初始化所有可训练参数
- saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=50) #模型保存器
- counter = 0 #counter记录训练步数
- for epoch in range(args.epoch): #训练epoch数
- shuffle(train_picture_list) #每训练一个epoch,就打乱一下输入的顺序
- for step in range(len(train_picture_list)): #每个训练epoch中的训练step数
- counter += 1
- picture_name, _ = os.path.splitext(os.path.basename(train_picture_list[step])) #获取不包含路径和格式的输入图片名称
- #读取一张训练图片,一张训练标签,以及相应的高和宽
- picture_resize, label_resize, picture_height, picture_width = ImageReader(file_name=picture_name, picture_path=args.train_picture_path, label_path=args.train_label_path, picture_format = args.train_picture_format, label_format = args.train_label_format, size = args.image_size)
- batch_picture = np.expand_dims(np.array(picture_resize).astype(np.float32), axis = 0) #填充维度
- batch_label = np.expand_dims(np.array(label_resize).astype(np.float32), axis = 0) #填充维度
- feed_dict = { train_picture : batch_picture, train_label : batch_label } #构造feed_dict
- gen_loss_value, dis_loss_value, _ = sess.run([gen_loss, dis_loss, train_op], feed_dict=feed_dict) #得到每个step中的生成器和判别器loss
- if counter % args.save_pred_every == 0: #每过save_pred_every次保存模型
- save(saver, sess, args.snapshot_dir, counter)
- if counter % args.summary_pred_every == 0: #每过summary_pred_every次保存训练日志
- gen_loss_sum_value, discriminator_sum_value = sess.run([gen_loss_sum, dis_loss_sum], feed_dict=feed_dict)
- summary_writer.add_summary(gen_loss_sum_value, counter)
- summary_writer.add_summary(discriminator_sum_value, counter)
- if counter % args.write_pred_every == 0: #每过write_pred_every次写一下训练的可视化结果
- gen_label_value = sess.run(gen_label, feed_dict=feed_dict) #run出生成器的输出
- write_image = get_write_picture(picture_resize, gen_label_value, label_resize, picture_height, picture_width) #得到训练的可视化结果
- write_image_name = args.out_dir + "/out"+ str(counter) + ".png" #待保存的训练可视化结果路径与名称
- cv2.imwrite(write_image_name, write_image) #保存训练的可视化结果
- print('epoch {:d} step {:d} \t gen_loss = {:.3f}, dis_loss = {:.3f}'.format(epoch, step, gen_loss_value, dis_loss_value))
- if __name__ == '__main__':
- main()
然后是image_reader.py文件:
- import os
- import numpy as np
- import tensorflow as tf
- import cv2
- #读取图片的函数,接收六个参数
- #输入参数分别是图片名,图片路径,标签路径,图片格式,标签格式,需要调整的尺寸大小
- def ImageReader(file_name, picture_path, label_path, picture_format = ".png", label_format = ".jpg", size = 256):
- picture_name = picture_path + file_name + picture_format #得到图片名称和路径
- label_name = label_path + file_name + label_format #得到标签名称和路径
- picture = cv2.imread(picture_name, 1) #读取图片
- label = cv2.imread(label_name, 1) #读取标签
- height = picture.shape[0] #得到图片的高
- width = picture.shape[1] #得到图片的宽
- picture_resize_t = cv2.resize(picture, (size, size)) #调整图片的尺寸,改变成网络输入的大小
- picture_resize = picture_resize_t / 127.5 - 1. #归一化图片
- label_resize_t = cv2.resize(label, (size, size)) #调整标签的尺寸,改变成网络输入的大小
- label_resize = label_resize_t / 127.5 - 1. #归一化标签
- return picture_resize, label_resize, height, width #返回网络输入的图片,标签,还有原图片和标签的长宽
接着是net.py文件:
- import numpy as np
- import tensorflow as tf
- import math
- #构造可训练参数
- def make_var(name, shape, trainable = True):
- return tf.get_variable(name, shape, trainable = trainable)
- #定义卷积层
- def conv2d(input_, output_dim, kernel_size, stride, padding = "SAME", name = "conv2d", biased = False):
- input_dim = input_.get_shape()[-1]
- with tf.variable_scope(name):
- kernel = make_var(name = 'weights', shape=[kernel_size, kernel_size, input_dim, output_dim])
- output = tf.nn.conv2d(input_, kernel, [1, stride, stride, 1], padding = padding)
- if biased:
- biases = make_var(name = 'biases', shape = [output_dim])
- output = tf.nn.bias_add(output, biases)
- return output
- #定义空洞卷积层
- def atrous_conv2d(input_, output_dim, kernel_size, dilation, padding = "SAME", name = "atrous_conv2d", biased = False):
- input_dim = input_.get_shape()[-1]
- with tf.variable_scope(name):
- kernel = make_var(name = 'weights', shape = [kernel_size, kernel_size, input_dim, output_dim])
- output = tf.nn.atrous_conv2d(input_, kernel, dilation, padding = padding)
- if biased:
- biases = make_var(name = 'biases', shape = [output_dim])
- output = tf.nn.bias_add(output, biases)
- return output
- #定义反卷积层
- def deconv2d(input_, output_dim, kernel_size, stride, padding = "SAME", name = "deconv2d"):
- input_dim = input_.get_shape()[-1]
- input_height = int(input_.get_shape()[1])
- input_width = int(input_.get_shape()[2])
- with tf.variable_scope(name):
- kernel = make_var(name = 'weights', shape = [kernel_size, kernel_size, output_dim, input_dim])
- output = tf.nn.conv2d_transpose(input_, kernel, [1, input_height * 2, input_width * 2, output_dim], [1, 2, 2, 1], padding = "SAME")
- return output
- #定义batchnorm(批次归一化)层
- def batch_norm(input_, name="batch_norm"):
- with tf.variable_scope(name):
- input_dim = input_.get_shape()[-1]
- scale = tf.get_variable("scale", [input_dim], initializer=tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32))
- offset = tf.get_variable("offset", [input_dim], initializer=tf.constant_initializer(0.0))
- mean, variance = tf.nn.moments(input_, axes=[1,2], keep_dims=True)
- epsilon = 1e-5
- inv = tf.rsqrt(variance + epsilon)
- normalized = (input_-mean)*inv
- output = scale*normalized + offset
- return output
- #定义lrelu激活层
- def lrelu(x, leak=0.2, name = "lrelu"):
- return tf.maximum(x, leak*x)
- #定义生成器,采用UNet架构,主要由8个卷积层和8个反卷积层组成
- def generator(image, gf_dim=64, reuse=False, name="generator"):
- input_dim = int(image.get_shape()[-1]) #获取输入通道
- dropout_rate = 0.5 #定义dropout的比例
- with tf.variable_scope(name):
- if reuse:
- tf.get_variable_scope().reuse_variables()
- else:
- assert tf.get_variable_scope().reuse is False
- #第一个卷积层,输出尺度[1, 128, 128, 64]
- e1 = batch_norm(conv2d(input_=image, output_dim=gf_dim, kernel_size=4, stride=2, name='g_e1_conv'), name='g_bn_e1')
- #第二个卷积层,输出尺度[1, 64, 64, 128]
- e2 = batch_norm(conv2d(input_=lrelu(e1), output_dim=gf_dim*2, kernel_size=4, stride=2, name='g_e2_conv'), name='g_bn_e2')
- #第三个卷积层,输出尺度[1, 32, 32, 256]
- e3 = batch_norm(conv2d(input_=lrelu(e2), output_dim=gf_dim*4, kernel_size=4, stride=2, name='g_e3_conv'), name='g_bn_e3')
- #第四个卷积层,输出尺度[1, 16, 16, 512]
- e4 = batch_norm(conv2d(input_=lrelu(e3), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e4_conv'), name='g_bn_e4')
- #第五个卷积层,输出尺度[1, 8, 8, 512]
- e5 = batch_norm(conv2d(input_=lrelu(e4), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e5_conv'), name='g_bn_e5')
- #第六个卷积层,输出尺度[1, 4, 4, 512]
- e6 = batch_norm(conv2d(input_=lrelu(e5), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e6_conv'), name='g_bn_e6')
- #第七个卷积层,输出尺度[1, 2, 2, 512]
- e7 = batch_norm(conv2d(input_=lrelu(e6), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e7_conv'), name='g_bn_e7')
- #第八个卷积层,输出尺度[1, 1, 1, 512]
- e8 = batch_norm(conv2d(input_=lrelu(e7), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e8_conv'), name='g_bn_e8')
- #第一个反卷积层,输出尺度[1, 2, 2, 512]
- d1 = deconv2d(input_=tf.nn.relu(e8), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d1')
- d1 = tf.nn.dropout(d1, dropout_rate) #随机扔掉一般的输出
- d1 = tf.concat([batch_norm(d1, name='g_bn_d1'), e7], 3)
- #第二个反卷积层,输出尺度[1, 4, 4, 512]
- d2 = deconv2d(input_=tf.nn.relu(d1), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d2')
- d2 = tf.nn.dropout(d2, dropout_rate) #随机扔掉一般的输出
- d2 = tf.concat([batch_norm(d2, name='g_bn_d2'), e6], 3)
- #第三个反卷积层,输出尺度[1, 8, 8, 512]
- d3 = deconv2d(input_=tf.nn.relu(d2), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d3')
- d3 = tf.nn.dropout(d3, dropout_rate) #随机扔掉一般的输出
- d3 = tf.concat([batch_norm(d3, name='g_bn_d3'), e5], 3)
- #第四个反卷积层,输出尺度[1, 16, 16, 512]
- d4 = deconv2d(input_=tf.nn.relu(d3), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d4')
- d4 = tf.concat([batch_norm(d4, name='g_bn_d4'), e4], 3)
- #第五个反卷积层,输出尺度[1, 32, 32, 256]
- d5 = deconv2d(input_=tf.nn.relu(d4), output_dim=gf_dim*4, kernel_size=4, stride=2, name='g_d5')
- d5 = tf.concat([batch_norm(d5, name='g_bn_d5'), e3], 3)
- #第六个反卷积层,输出尺度[1, 64, 64, 128]
- d6 = deconv2d(input_=tf.nn.relu(d5), output_dim=gf_dim*2, kernel_size=4, stride=2, name='g_d6')
- d6 = tf.concat([batch_norm(d6, name='g_bn_d6'), e2], 3)
- #第七个反卷积层,输出尺度[1, 128, 128, 64]
- d7 = deconv2d(input_=tf.nn.relu(d6), output_dim=gf_dim, kernel_size=4, stride=2, name='g_d7')
- d7 = tf.concat([batch_norm(d7, name='g_bn_d7'), e1], 3)
- #第八个反卷积层,输出尺度[1, 256, 256, 3]
- d8 = deconv2d(input_=tf.nn.relu(d7), output_dim=input_dim, kernel_size=4, stride=2, name='g_d8')
- return tf.nn.tanh(d8)
- #定义判别器
- def discriminator(image, targets, df_dim=64, reuse=False, name="discriminator"):
- with tf.variable_scope(name):
- if reuse:
- tf.get_variable_scope().reuse_variables()
- else:
- assert tf.get_variable_scope().reuse is False
- dis_input = tf.concat([image, targets], 3)
- #第1个卷积模块,输出尺度: 1*128*128*64
- h0 = lrelu(conv2d(input_ = dis_input, output_dim = df_dim, kernel_size = 4, stride = 2, name='d_h0_conv'))
- #第2个卷积模块,输出尺度: 1*64*64*128
- h1 = lrelu(batch_norm(conv2d(input_ = h0, output_dim = df_dim*2, kernel_size = 4, stride = 2, name='d_h1_conv'), name='d_bn1'))
- #第3个卷积模块,输出尺度: 1*32*32*256
- h2 = lrelu(batch_norm(conv2d(input_ = h1, output_dim = df_dim*4, kernel_size = 4, stride = 2, name='d_h2_conv'), name='d_bn2'))
- #第4个卷积模块,输出尺度: 1*32*32*512
- h3 = lrelu(batch_norm(conv2d(input_ = h2, output_dim = df_dim*8, kernel_size = 4, stride = 1, name='d_h3_conv'), name='d_bn3'))
- #最后一个卷积模块,输出尺度: 1*32*32*1
- output = conv2d(input_ = h3, output_dim = 1, kernel_size = 4, stride = 1, name='d_h4_conv')
- dis_out = tf.sigmoid(output) #在输出之前经过sigmoid层,因为需要进行log运算
- return dis_out
上面就是训练所需的全部代码,大家可以看到,在net.py文件中。生成器使用UNet结构,在生成器和判别器中,image参数就是指的条件,并且在生成器的输入中,随机噪声被去掉了(仅仅输入了条件);在判别器的输入中,条件和待判别的图像被拼接(concat)了起来。
如果需要开启训练,可以调整train.py中的最后四个参数,根据自己的需求调整训练输入的图片和标签文件路径和相应的格式。另外,由于CGAN训练中需要匹配条件与判别图片,因此,训练读取的图片和标签名称应该是匹配的,在image_reader.py中也能看到,程序是按照同一个名称,去检索训练一个批次输入的图像和对应的标签。
下面是evaluate.py文件:
- import argparse
- import sys
- import math
- import tensorflow as tf
- import numpy as np
- import glob
- import cv2
- from image_reader import *
- from net import *
- parser = argparse.ArgumentParser(description='')
- parser.add_argument("--test_picture_path", default='./dataset/test_picture/', help="path of test datas.")#网络测试输入的图片路径
- parser.add_argument("--test_label_path", default='./dataset/test_label/', help="path of test datas.") #网络测试输入的标签路径
- parser.add_argument("--image_size", type=int, default=256, help="load image size") #网络输入的尺度
- parser.add_argument("--test_picture_format", default='.png', help="format of test pictures.") #网络测试输入的图片的格式
- parser.add_argument("--test_label_format", default='.jpg', help="format of test labels.") #网络测试时读取的标签的格式
- parser.add_argument("--snapshots", default='./snapshots/',help="Path of Snapshots") #读取训练好的模型参数的路径
- parser.add_argument("--out_dir", default='./test_output/',help="Output Folder") #保存网络测试输出图片的路径
- args = parser.parse_args() #用来解析命令行参数
- def cv_inv_proc(img): #cv_inv_proc函数将读取图片时归一化的图片还原成原图
- img_rgb = (img + 1.) * 127.5
- return img_rgb.astype(np.float32) #返回bgr格式的图像,方便cv2写图像
- def get_write_picture(picture, gen_label, label, height, width): #get_write_picture函数得到网络测试的结果
- picture_image = cv_inv_proc(picture) #还原输入的图像
- gen_label_image = cv_inv_proc(gen_label[0]) #还原生成的结果
- label_image = cv_inv_proc(label) #还原读取的标签
- inv_picture_image = cv2.resize(picture_image, (width, height)) #将输入图像还原到原大小
- inv_gen_label_image = cv2.resize(gen_label_image, (width, height)) #将生成的结果还原到原大小
- inv_label_image = cv2.resize(label_image, (width, height)) #将标签还原到原大小
- output = np.concatenate((inv_picture_image, inv_gen_label_image, inv_label_image), axis=1) #拼接得到输出结果
- return output
- def main():
- if not os.path.exists(args.out_dir): #如果保存测试结果的文件夹不存在则创建
- os.makedirs(args.out_dir)
- test_picture_list = glob.glob(os.path.join(args.test_picture_path, "*")) #得到测试输入图像路径名称列表
- test_picture = tf.placeholder(tf.float32, shape=[1, 256, 256, 3], name='test_picture') #测试输入的图像
- gen_label = generator(image=test_picture, gf_dim=64, reuse=False, name='generator') #得到生成器的生成结果
- restore_var = [v for v in tf.global_variables() if 'generator' in v.name] #需要载入的已训练的模型参数
- config = tf.ConfigProto()
- config.gpu_options.allow_growth = True #设定显存不超量使用
- sess = tf.Session(config=config) #建立会话层
- saver = tf.train.Saver(var_list=restore_var, max_to_keep=1) #导入模型参数时使用
- checkpoint = tf.train.latest_checkpoint(args.snapshots) #读取模型参数
- saver.restore(sess, checkpoint) #导入模型参数
- for step in range(len(test_picture_list)):
- picture_name, _ = os.path.splitext(os.path.basename(test_picture_list[step])) #得到一张网络测试的输入图像名字
- #读取一张测试图片,一张标签,以及相应的高和宽
- picture_resize, label_resize, picture_height, picture_width = ImageReader(file_name=picture_name,
- picture_path=args.test_picture_path,
- label_path=args.test_label_path,
- picture_format=args.test_picture_format,
- label_format=args.test_label_format,
- size=args.image_size)
- batch_picture = np.expand_dims(np.array(picture_resize).astype(np.float32), axis=0) #填充维度
- feed_dict = {test_picture: batch_picture} #构造feed_dict
- gen_label_value = sess.run(gen_label, feed_dict=feed_dict) #得到生成结果
- write_image = get_write_picture(picture_resize, gen_label_value, label_resize, picture_height, picture_width) #得到一张需要存的图像
- write_image_name = args.out_dir + picture_name + ".png" #为上述的图像构造保存路径与文件名
- cv2.imwrite(write_image_name, write_image) #保存测试结果
- print('step {:d}'.format(step))
- if __name__ == '__main__':
- main()
如果需要测试训练完毕的模型,相应地更改测试图片和标签输入路径和格式的四个参数即可,并设置读取模型权重的路径即可。
下面,笔者就以训练的填充轮廓生成建筑图片的例子为大家展示一下CGAN的效果:
首先是训练时的可视化输出图像,从左到右,第一张是网络的输入图片(条件),第二张是生成器生成的建筑图像,第三张是真实的建筑图像(标签)。
首先是训练200次的输出:
然后是训练5600次的输出:
然后是训练19000次的输出:
然后是训练36500次的输出:
然后是训练46700次的输出:
然后是训练65700次的输出:
然后是训练72400次的输出:
最后是训练96300次的输出:
下面展示一下训练的loss曲线:
生成器的loss曲线:
判别器的loss曲线:
最后展示一下在测试集上面的效果:
左边是输入的图像(条件),中间是生成的图像,右边是标签(真实的样本)。
上面就是在测试集上面的效果,读者朋友们可以从文章开头笔者放出的链接中下载数据集进行实验。
在train.py中,如果将lamda_l1_weight参数改成100,就是pix2pix的做法,笔者放了一些测试集的效果(训练有一些过拟合):
到这里,CGAN的模型搭建及解析就接近尾声了,很感谢Mehdi Mirza和Simon Osindero,为大家带来条件监督的生成对抗网络算法。CGAN还可以做很多有趣的事情,比如说这个有趣的工作:AI可能真的要代替插画师了……,项目链接https://make.girls.moe/#/,通过CGAN有条件地生成二次元萌妹。