GAN:
生成对抗网络(Generative Adversarial Networks [4])主要由生成器 (generator) 和判别器 (discriminator)组成。它的原理也比较清晰,generator 负责输入随机噪声z,输出一个图片 G(z) ,而真实样本x,判别器 D 则用尽全力希望把真实样本和虚假样本区分开来。而 G 则希望产生的 G(z) 以假乱真,欺骗判别器,让其判断不出来。从而有了这么一种对抗的关系。
gan的缺点:
1.不适合处理离散形式的数据 例如——文本
2.训练不稳定
3.梯度消失
4.模式崩溃
常见的gan:
DCGAN:2、3(batchnorm)
WGAN:2(wassertein距离代替JS散度)、4
WGAN-GP:完爆以上的GAN
LSGAN:使用了最小二乘损失函数代替了GAN的损失函数,缓解了GAN训练不稳定和生成图像质量差多样性不足的问题
DCGAN:
DCGAN原理和 GAN 是一样的,DCGAN 可以理解为 GAN 和 CNN 的结合,同时 GAN 其实并不好收敛,DCGAN 在网络收敛上做了一些改进工作。比如,G 网络中采用 transposed convolutional layer 进行上采样,D 加入 stride conv 替代 pooling。G采用ReLU,最后一层采用tanh。D采用LeakyReLU。
1.cgan--mnist(手写数字生成)
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
#数据输入
mnist = input_data.read_data_sets('mnist/', one_hot=True)
mb_size = 64
Z_dim = 100
X_dim = mnist.train.images.shape[1]
y_dim = mnist.train.labels.shape[1]
h_dim = 128
#返回随机值
def xavier_init(size):
in_dim = size[0]
xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
return tf.random_normal(shape=size, stddev=xavier_stddev)
#D网络,这里是一个简单的神经网络,x是输入图片向量,y是相应的label
def discriminator(x, y):
inputs = tf.concat(axis=1, values=[x, y])
D_h1 = tf.nn.relu(tf.matmul(inputs, D_W1) + D_b1)
D_logit = tf.matmul(D_h1, D_W2) + D_b2
D_prob = tf.nn.sigmoid(D_logit)
return D_prob, D_logit
#G网络
def generator(z, y):
inputs = tf.concat(axis=1, values=[z, y])
G_h1 = tf.nn.relu(tf.matmul(inputs, G_W1) + G_b1)
G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
G_prob = tf.nn.sigmoid(G_log_prob)
return G_prob
#噪声产生的函数
def sample_Z(m, n):
return np.random.uniform(-1., 1., size=[m, n])
def plot(samples):
fig = plt.figure(figsize=(4, 4))
gs = gridspec.GridSpec(4, 4)
gs.update(wspace=0.05, hspace=0.05)
for i, sample in enumerate(samples):
ax = plt.subplot(gs[i])
plt.axis('off')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
return fig
#X代表输入图片,应该是28*28,但是这里没有使用CNN,y是相应的label
""" Discriminator Net model """
X = tf.placeholder(tf.float32, shape=[None, 784])
y = tf.placeholder(tf.float32, shape=[None, y_dim])
#权重,CGAN的输入是将图片输入与label concat起来,所以权重维度为784+10
D_W1 = tf.Variable(xavier_init([X_dim + y_dim, h_dim]))
D_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
#第二层有h_dim个节点
D_W2 = tf.Variable(xavier_init([h_dim, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))
theta_D = [D_W1, D_W2, D_b1, D_b2]
#G网络参数,输入维度为Z_dim+y_dim,中间层有h_dim个节点,输出X_dim的数据
""" Generator Net model """
Z = tf.placeholder(tf.float32, shape=[None, Z_dim])
#权重
G_W1 = tf.Variable(xavier_init([Z_dim + y_dim, h_dim]))
G_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
G_W2 = tf.Variable(xavier_init([h_dim, X_dim]))
G_b2 = tf.Variable(tf.zeros(shape=[X_dim]))
theta_G = [G_W1, G_W2, G_b1, G_b2]
#生成网络,基本和GAN一致
G_sample = generator(Z, y)
D_real, D_logit_real = discriminator(X, y)
D_fake, D_logit_fake = discriminator(G_sample, y)
#优化式
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))
#训练
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
#输出图片在out文件夹
if not os.path.exists('out/'):
os.makedirs('out/')
i = 0
for it in range(1000000):
#mb_size是网络训练时用的Batchsize,为100
X_mb, y_mb = mnist.train.next_batch(mb_size)
#Z_dim是noise的维度,为100
Z_sample = sample_Z(mb_size, Z_dim)
#交替最小化训练
_, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: Z_sample, y:y_mb})
_, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: Z_sample, y:y_mb})
#输出训练时的参数
if it % 1000 == 0:
print('Iter: {}'.format(it))
print('D loss: {:.4}'. format(D_loss_curr))
print('G_loss: {:.4}'.format(G_loss_curr))
# 测试
if it % 1000 == 0:
#n_sample 是G网络测试用的Batchsize,为16,所以输出的png图有16张
n_sample = 16
Z_sample = sample_Z(n_sample, Z_dim)#输入的噪声,尺寸为batchsize*noise维度
y_sample = np.zeros(shape=[n_sample, y_dim])#输入的label,尺寸为batchsize*label维度
y_sample[:, 7] = 1 #输出7
samples = sess.run(G_sample, feed_dict={Z: Z_sample, y:y_sample})#G网络的输入
fig = plot(samples)
plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')#输出生成的图片
i += 1
plt.close(fig)
2.cgan--(由轮廓生成真实图)
image_reader.py
import os
import numpy as np
import tensorflow as tf
import cv2
#读取图片的函数,接收六个参数
#输入参数分别是图片名,图片路径,标签路径,图片格式,标签格式,需要调整的尺寸大小
def ImageReader(file_name, picture_path, label_path, picture_format = ".png", label_format = ".jpg", size = 256):
picture_name = picture_path + file_name + picture_format #得到图片名称和路径
label_name = label_path + file_name + label_format #得到标签名称和路径
picture = cv2.imread(picture_name, 1) #读取图片
label = cv2.imread(label_name, 1) #读取标签
height = picture.shape[0] #得到图片的高
width = picture.shape[1] #得到图片的宽
picture_resize_t = cv2.resize(picture, (size, size)) #调整图片的尺寸,改变成网络输入的大小
picture_resize = picture_resize_t / 127.5 - 1. #归一化图片
label_resize_t = cv2.resize(label, (size, size)) #调整标签的尺寸,改变成网络输入的大小
label_resize = label_resize_t / 127.5 - 1. #归一化标签
return picture_resize, label_resize, height, width #返回网络输入的图片,标签,还有原图片和标签的长宽
net.py
import numpy as np
import tensorflow as tf
import math
#构造可训练参数
def make_var(name, shape, trainable = True):
return tf.get_variable(name, shape, trainable = trainable)
#定义卷积层
def conv2d(input_, output_dim, kernel_size, stride, padding = "SAME", name = "conv2d", biased = False):
input_dim = input_.get_shape()[-1]
with tf.variable_scope(name):
kernel = make_var(name = 'weights', shape=[kernel_size, kernel_size, input_dim, output_dim])
output = tf.nn.conv2d(input_, kernel, [1, stride, stride, 1], padding = padding)
if biased:
biases = make_var(name = 'biases', shape = [output_dim])
output = tf.nn.bias_add(output, biases)
return output
#定义空洞卷积层
def atrous_conv2d(input_, output_dim, kernel_size, dilation, padding = "SAME", name = "atrous_conv2d", biased = False):
input_dim = input_.get_shape()[-1]
with tf.variable_scope(name):
kernel = make_var(name = 'weights', shape = [kernel_size, kernel_size, input_dim, output_dim])
output = tf.nn.atrous_conv2d(input_, kernel, dilation, padding = padding)
if biased:
biases = make_var(name = 'biases', shape = [output_dim])
output = tf.nn.bias_add(output, biases)
return output
#定义反卷积层
def deconv2d(input_, output_dim, kernel_size, stride, padding = "SAME", name = "deconv2d"):
input_dim = input_.get_shape()[-1]
input_height = int(input_.get_shape()[1])
input_width = int(input_.get_shape()[2])
with tf.variable_scope(name):
kernel = make_var(name = 'weights', shape = [kernel_size, kernel_size, output_dim, input_dim])
output = tf.nn.conv2d_transpose(input_, kernel, [1, input_height * 2, input_width * 2, output_dim], [1, 2, 2, 1], padding = "SAME")
return output
#定义batchnorm(批次归一化)层
def batch_norm(input_, name="batch_norm"):
with tf.variable_scope(name):
input_dim = input_.get_shape()[-1]
scale = tf.get_variable("scale", [input_dim], initializer=tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32))
offset = tf.get_variable("offset", [input_dim], initializer=tf.constant_initializer(0.0))
mean, variance = tf.nn.moments(input_, axes=[1,2], keep_dims=True)
epsilon = 1e-5
inv = tf.rsqrt(variance + epsilon)
normalized = (input_-mean)*inv
output = scale*normalized + offset
return output
#定义lrelu激活层
def lrelu(x, leak=0.2, name = "lrelu"):
return tf.maximum(x, leak*x)
#定义生成器,采用UNet架构,主要由8个卷积层和8个反卷积层组成
def generator(image, gf_dim=64, reuse=False, name="generator"):
input_dim = int(image.get_shape()[-1]) #获取输入通道
dropout_rate = 0.5 #定义dropout的比例
with tf.variable_scope(name):
if reuse:
tf.get_variable_scope().reuse_variables()
else:
assert tf.get_variable_scope().reuse is False
#第一个卷积层,输出尺度[1, 128, 128, 64]
e1 = batch_norm(conv2d(input_=image, output_dim=gf_dim, kernel_size=4, stride=2, name='g_e1_conv'), name='g_bn_e1')
#第二个卷积层,输出尺度[1, 64, 64, 128]
e2 = batch_norm(conv2d(input_=lrelu(e1), output_dim=gf_dim*2, kernel_size=4, stride=2, name='g_e2_conv'), name='g_bn_e2')
#第三个卷积层,输出尺度[1, 32, 32, 256]
e3 = batch_norm(conv2d(input_=lrelu(e2), output_dim=gf_dim*4, kernel_size=4, stride=2, name='g_e3_conv'), name='g_bn_e3')
#第四个卷积层,输出尺度[1, 16, 16, 512]
e4 = batch_norm(conv2d(input_=lrelu(e3), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e4_conv'), name='g_bn_e4')
#第五个卷积层,输出尺度[1, 8, 8, 512]
e5 = batch_norm(conv2d(input_=lrelu(e4), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e5_conv'), name='g_bn_e5')
#第六个卷积层,输出尺度[1, 4, 4, 512]
e6 = batch_norm(conv2d(input_=lrelu(e5), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e6_conv'), name='g_bn_e6')
#第七个卷积层,输出尺度[1, 2, 2, 512]
e7 = batch_norm(conv2d(input_=lrelu(e6), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e7_conv'), name='g_bn_e7')
#第八个卷积层,输出尺度[1, 1, 1, 512]
e8 = batch_norm(conv2d(input_=lrelu(e7), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_e8_conv'), name='g_bn_e8')
#第一个反卷积层,输出尺度[1, 2, 2, 512]
d1 = deconv2d(input_=tf.nn.relu(e8), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d1')
d1 = tf.nn.dropout(d1, dropout_rate) #随机扔掉一般的输出
d1 = tf.concat([batch_norm(d1, name='g_bn_d1'), e7], 3)
#第二个反卷积层,输出尺度[1, 4, 4, 512]
d2 = deconv2d(input_=tf.nn.relu(d1), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d2')
d2 = tf.nn.dropout(d2, dropout_rate) #随机扔掉一般的输出
d2 = tf.concat([batch_norm(d2, name='g_bn_d2'), e6], 3)
#第三个反卷积层,输出尺度[1, 8, 8, 512]
d3 = deconv2d(input_=tf.nn.relu(d2), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d3')
d3 = tf.nn.dropout(d3, dropout_rate) #随机扔掉一般的输出
d3 = tf.concat([batch_norm(d3, name='g_bn_d3'), e5], 3)
#第四个反卷积层,输出尺度[1, 16, 16, 512]
d4 = deconv2d(input_=tf.nn.relu(d3), output_dim=gf_dim*8, kernel_size=4, stride=2, name='g_d4')
d4 = tf.concat([batch_norm(d4, name='g_bn_d4'), e4], 3)
#第五个反卷积层,输出尺度[1, 32, 32, 256]
d5 = deconv2d(input_=tf.nn.relu(d4), output_dim=gf_dim*4, kernel_size=4, stride=2, name='g_d5')
d5 = tf.concat([batch_norm(d5, name='g_bn_d5'), e3], 3)
#第六个反卷积层,输出尺度[1, 64, 64, 128]
d6 = deconv2d(input_=tf.nn.relu(d5), output_dim=gf_dim*2, kernel_size=4, stride=2, name='g_d6')
d6 = tf.concat([batch_norm(d6, name='g_bn_d6'), e2], 3)
#第七个反卷积层,输出尺度[1, 128, 128, 64]
d7 = deconv2d(input_=tf.nn.relu(d6), output_dim=gf_dim, kernel_size=4, stride=2, name='g_d7')
d7 = tf.concat([batch_norm(d7, name='g_bn_d7'), e1], 3)
#第八个反卷积层,输出尺度[1, 256, 256, 3]
d8 = deconv2d(input_=tf.nn.relu(d7), output_dim=input_dim, kernel_size=4, stride=2, name='g_d8')
return tf.nn.tanh(d8)
#定义判别器
def discriminator(image, targets, df_dim=64, reuse=False, name="discriminator"):
with tf.variable_scope(name):
if reuse:
tf.get_variable_scope().reuse_variables()
else:
assert tf.get_variable_scope().reuse is False
dis_input = tf.concat([image, targets], 3)
#第1个卷积模块,输出尺度: 1*128*128*64
h0 = lrelu(conv2d(input_ = dis_input, output_dim = df_dim, kernel_size = 4, stride = 2, name='d_h0_conv'))
#第2个卷积模块,输出尺度: 1*64*64*128
h1 = lrelu(batch_norm(conv2d(input_ = h0, output_dim = df_dim*2, kernel_size = 4, stride = 2, name='d_h1_conv'), name='d_bn1'))
#第3个卷积模块,输出尺度: 1*32*32*256
h2 = lrelu(batch_norm(conv2d(input_ = h1, output_dim = df_dim*4, kernel_size = 4, stride = 2, name='d_h2_conv'), name='d_bn2'))
#第4个卷积模块,输出尺度: 1*32*32*512
h3 = lrelu(batch_norm(conv2d(input_ = h2, output_dim = df_dim*8, kernel_size = 4, stride = 1, name='d_h3_conv'), name='d_bn3'))
#最后一个卷积模块,输出尺度: 1*32*32*1
output = conv2d(input_ = h3, output_dim = 1, kernel_size = 4, stride = 1, name='d_h4_conv')
dis_out = tf.sigmoid(output) #在输出之前经过sigmoid层,因为需要进行log运算
return dis_out
train.py
from __future__ import print_function
import argparse
from random import shuffle
import random
import os
import sys
import math
import tensorflow as tf
import glob
import cv2
from image_reader import *
from net import *
parser = argparse.ArgumentParser(description='')
parser.add_argument("--snapshot_dir", default='./snapshots', help="path of snapshots") #保存模型的路径
parser.add_argument("--out_dir", default='./train_out', help="path of train outputs") #训练时保存可视化输出的路径
parser.add_argument("--image_size", type=int, default=256, help="load image size") #网络输入的尺度
parser.add_argument("--random_seed", type=int, default=1234, help="random seed") #随机数种子
parser.add_argument('--base_lr', type=float, default=0.0002, help='initial learning rate for adam') #学习率
parser.add_argument('--epoch', dest='epoch', type=int, default=200, help='# of epoch') #训练的epoch数量
parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term of adam') #adam优化器的beta1参数
parser.add_argument("--summary_pred_every", type=int, default=200, help="times to summary.") #训练中每过多少step保存训练日志(记录一下loss值)
parser.add_argument("--write_pred_every", type=int, default=100, help="times to write.") #训练中每过多少step保存可视化结果
parser.add_argument("--save_pred_every", type=int, default=1000, help="times to save.") #训练中每过多少step保存模型(可训练参数)
parser.add_argument("--lamda_l1_weight", type=float, default=0.0, help="L1 lamda") #训练中L1_Loss前的乘数
parser.add_argument("--lamda_gan_weight", type=float, default=1.0, help="GAN lamda") #训练中GAN_Loss前的乘数
parser.add_argument("--train_picture_format", default='.jpg', help="format of training datas.") #网络训练输入的图片的格式(图片在CGAN中被当做条件)
parser.add_argument("--train_label_format", default='.jpg', help="format of training labels.") #网络训练输入的标签的格式(标签在CGAN中被当做真样本)
parser.add_argument("--train_picture_path", default='img/', help="path of training datas.") #网络训练输入的图片路径
parser.add_argument("--train_label_path", default='img2/', help="path of training labels.") #网络训练输入的标签路径
args = parser.parse_args() #用来解析命令行参数
EPS = 1e-12 #EPS用于保证log函数里面的参数大于零
def save(saver, sess, logdir, step): #保存模型的save函数
model_name = 'model' #保存的模型名前缀
checkpoint_path = os.path.join(logdir, model_name) #模型的保存路径与名称
if not os.path.exists(logdir): #如果路径不存在即创建
os.makedirs(logdir)
saver.save(sess, checkpoint_path, global_step=step) #保存模型
print('The checkpoint has been created.')
def cv_inv_proc(img): #cv_inv_proc函数将读取图片时归一化的图片还原成原图
img_rgb = (img + 1.) * 127.5
return img_rgb.astype(np.float32) #返回bgr格式的图像,方便cv2写图像
def get_write_picture(picture, gen_label, label, height, width): #get_write_picture函数得到训练过程中的可视化结果
picture_image = cv_inv_proc(picture) #还原输入的图像
gen_label_image = cv_inv_proc(gen_label[0]) #还原生成的样本
label_image = cv_inv_proc(label) #还原真实的样本(标签)
inv_picture_image = cv2.resize(picture_image, (width, height)) #还原图像的尺寸
inv_gen_label_image = cv2.resize(gen_label_image, (width, height)) #还原生成的样本的尺寸
inv_label_image = cv2.resize(label_image, (width, height)) #还原真实的样本的尺寸
output = np.concatenate((inv_picture_image, inv_gen_label_image, inv_label_image), axis=1) #把他们拼起来
return output
def l1_loss(src, dst): #定义l1_loss
return tf.reduce_mean(tf.abs(src - dst))
def main(): #训练程序的主函数
if not os.path.exists(args.snapshot_dir): #如果保存模型参数的文件夹不存在则创建
os.makedirs(args.snapshot_dir)
if not os.path.exists(args.out_dir): #如果保存训练中可视化输出的文件夹不存在则创建
os.makedirs(args.out_dir)
#获取图片列表,构造一个图片和一个标签
train_picture_list = glob.glob(os.path.join(args.train_picture_path, "*")) #得到训练输入图像路径名称列表
tf.set_random_seed(args.random_seed) #初始一下随机数
train_picture = tf.placeholder(tf.float32,shape=[1, args.image_size, args.image_size, 3],name='train_picture') #输入的训练图像
train_label = tf.placeholder(tf.float32,shape=[1, args.image_size, args.image_size, 3],name='train_label') #输入的与训练图像匹配的标签
#生成器生成图片,判别器输出对真实标签(train_label)、对生成标签(gen_label)的判断结果
gen_label = generator(image=train_picture, gf_dim=64, reuse=False, name='generator') #得到生成器的输出
dis_real = discriminator(image=train_picture, targets=train_label, df_dim=64, reuse=False, name="discriminator") #判别器返回的对真实标签的判别结果
dis_fake = discriminator(image=train_picture, targets=gen_label, df_dim=64, reuse=True, name="discriminator") #判别器返回的对生成(虚假的)标签判别结果
#计算生成器的损失
gen_loss_GAN = tf.reduce_mean(-tf.log(dis_fake + EPS)) #计算生成器损失中的GAN_loss部分
gen_loss_L1 = tf.reduce_mean(l1_loss(gen_label, train_label)) #计算生成器损失中的L1_loss部分
gen_loss = gen_loss_GAN * args.lamda_gan_weight + gen_loss_L1 * args.lamda_l1_weight #计算生成器的loss
#计算判别器的损失
dis_loss = tf.reduce_mean(-(tf.log(dis_real + EPS) + tf.log(1 - dis_fake + EPS))) #计算判别器的loss
gen_loss_sum = tf.summary.scalar("gen_loss", gen_loss) #记录生成器loss的日志
dis_loss_sum = tf.summary.scalar("dis_loss", dis_loss) #记录判别器loss的日志
#把日志输出到文件夹-snapshot_dir
summary_writer = tf.summary.FileWriter(args.snapshot_dir, graph=tf.get_default_graph()) #日志记录器
g_vars = [v for v in tf.trainable_variables() if 'generator' in v.name] #所有生成器的可训练参数
d_vars = [v for v in tf.trainable_variables() if 'discriminator' in v.name] #所有判别器的可训练参数
#计算梯度
d_optim = tf.train.AdamOptimizer(args.base_lr, beta1=args.beta1) #判别器训练器
g_optim = tf.train.AdamOptimizer(args.base_lr, beta1=args.beta1) #生成器训练器
d_grads_and_vars = d_optim.compute_gradients(dis_loss, var_list=d_vars) #计算判别器参数梯度
d_train = d_optim.apply_gradients(d_grads_and_vars) #更新判别器参数
g_grads_and_vars = g_optim.compute_gradients(gen_loss, var_list=g_vars) #计算生成器参数梯度
g_train = g_optim.apply_gradients(g_grads_and_vars) #更新生成器参数
#设置GPU
train_op = tf.group(d_train, g_train) #train_op表示了参数更新操作
config = tf.ConfigProto()
config.gpu_options.allow_growth = True #设定显存不超量使用
#定义会话
sess = tf.Session(config=config) #新建会话层
init = tf.global_variables_initializer() #参数初始化器
sess.run(init) #初始化所有可训练参数
saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=50) #模型保存器
counter = 0 #counter记录训练步数
for epoch in range(args.epoch): #训练epoch数
shuffle(train_picture_list) #每训练一个epoch,就打乱一下输入的顺序
for step in range(len(train_picture_list)): #每个训练epoch中的训练step数
counter += 1
picture_name, _ = os.path.splitext(os.path.basename(train_picture_list[step])) #获取不包含路径和格式的输入图片名称
#读取一张训练图片,一张训练标签,以及相应的高和宽
picture_resize, label_resize, picture_height, picture_width = ImageReader(file_name=picture_name,
picture_path=args.train_picture_path, label_path=args.train_label_path,
picture_format = args.train_picture_format, label_format = args.train_label_format,
size = args.image_size)
#给图片和标签填充维度
batch_picture = np.expand_dims(np.array(picture_resize).astype(np.float32), axis = 0) #填充维度
batch_label = np.expand_dims(np.array(label_resize).astype(np.float32), axis = 0) #填充维度
print(batch_picture.shape, type(batch_label))
#喂入数据,计算各个损失
feed_dict = { train_picture : batch_picture, train_label : batch_label } #构造feed_dict
gen_loss_value, dis_loss_value, _ = sess.run([gen_loss, dis_loss, train_op], feed_dict=feed_dict) #得到每个step中的生成器和判别器loss
#按步数保存各种数据:模型、日志、可视化结果
if counter % args.save_pred_every == 100: #每过save_pred_every次保存模型
save(saver, sess, args.snapshot_dir, counter)
if counter % args.summary_pred_every == 100: #每过summary_pred_every次保存训练日志
gen_loss_sum_value, discriminator_sum_value = sess.run([gen_loss_sum, dis_loss_sum], feed_dict=feed_dict)
summary_writer.add_summary(gen_loss_sum_value, counter)
summary_writer.add_summary(discriminator_sum_value, counter)
if counter % args.write_pred_every == 100: #每过write_pred_every次写一下训练的可视化结果
gen_label_value = sess.run(gen_label, feed_dict=feed_dict) #run出生成器的输出
write_image = get_write_picture(picture_resize, gen_label_value, label_resize, picture_height, picture_width) #得到训练的可视化结果
write_image_name = args.out_dir + "/out"+ str(counter) + ".png" #待保存的训练可视化结果路径与名称
cv2.imwrite(write_image_name, write_image) #保存训练的可视化结果
print('epoch {:d} step {:d} \t gen_loss = {:.3f}, dis_loss = {:.3f}'.format(epoch, step, gen_loss_value, dis_loss_value))
if __name__ == '__main__':
main()
evaluate.py
import argparse
import sys
import math
import tensorflow as tf
import numpy as np
import glob
import cv2
from image_reader import *
from net import *
parser = argparse.ArgumentParser(description='')
parser.add_argument("--test_picture_path", default='img/', help="path of test datas.")#网络测试输入的图片路径
parser.add_argument("--test_label_path", default='img2/', help="path of test datas.") #网络测试输入的标签路径
parser.add_argument("--image_size", type=int, default=256, help="load image size") #网络输入的尺度
parser.add_argument("--test_picture_format", default='.jpg', help="format of test pictures.") #网络测试输入的图片的格式
parser.add_argument("--test_label_format", default='.jpg', help="format of test labels.") #网络测试时读取的标签的格式
parser.add_argument("--snapshots", default='snapshots/',help="Path of Snapshots") #读取训练好的模型参数的路径
parser.add_argument("--out_dir", default='test_output/',help="Output Folder") #保存网络测试输出图片的路径
args = parser.parse_args() #用来解析命令行参数
def cv_inv_proc(img): #cv_inv_proc函数将读取图片时归一化的图片还原成原图
img_rgb = (img + 1.) * 127.5
return img_rgb.astype(np.float32) #返回bgr格式的图像,方便cv2写图像
def get_write_picture(picture, gen_label, label, height, width): #get_write_picture函数得到网络测试的结果
picture_image = cv_inv_proc(picture) #还原输入的图像
gen_label_image = cv_inv_proc(gen_label[0]) #还原生成的结果
label_image = cv_inv_proc(label) #还原读取的标签
inv_picture_image = cv2.resize(picture_image, (width, height)) #将输入图像还原到原大小
inv_gen_label_image = cv2.resize(gen_label_image, (width, height)) #将生成的结果还原到原大小
inv_label_image = cv2.resize(label_image, (width, height)) #将标签还原到原大小
output = np.concatenate((inv_picture_image, inv_gen_label_image, inv_label_image), axis=1) #拼接得到输出结果
return output
def main():
if not os.path.exists(args.out_dir): #如果保存测试结果的文件夹不存在则创建
os.makedirs(args.out_dir)
test_picture_list = glob.glob(os.path.join(args.test_picture_path, "*")) #得到测试输入图像路径名称列表
test_picture = tf.placeholder(tf.float32, shape=[1, 256, 256, 3], name='test_picture') #测试输入的图像
gen_label = generator(image=test_picture, gf_dim=64, reuse=False, name='generator') #得到生成器的生成结果
restore_var = [v for v in tf.global_variables() if 'generator' in v.name] #需要载入的已训练的模型参数
config = tf.ConfigProto()
config.gpu_options.allow_growth = True #设定显存不超量使用
sess = tf.Session(config=config) #建立会话层
saver = tf.train.Saver(var_list=restore_var, max_to_keep=1) #导入模型参数时使用
checkpoint = tf.train.latest_checkpoint(args.snapshots) #读取模型参数
saver.restore(sess, checkpoint) #导入模型参数
for step in range(len(test_picture_list)):
picture_name, _ = os.path.splitext(os.path.basename(test_picture_list[step])) #得到一张网络测试的输入图像名字
#读取一张测试图片,一张标签,以及相应的高和宽
picture_resize, label_resize, picture_height, picture_width = ImageReader(file_name=picture_name,
picture_path=args.test_picture_path,
label_path=args.test_label_path,
picture_format=args.test_picture_format,
label_format=args.test_label_format,
size=args.image_size)
batch_picture = np.expand_dims(np.array(picture_resize).astype(np.float32), axis=0) #填充维度
feed_dict = {test_picture: batch_picture} #构造feed_dict
gen_label_value = sess.run(gen_label, feed_dict=feed_dict) #得到生成结果
write_image = get_write_picture(picture_resize, gen_label_value, label_resize, picture_height, picture_width) #得到一张需要存的图像
write_image_name = args.out_dir + picture_name + ".png" #为上述的图像构造保存路径与文件名
cv2.imwrite(write_image_name, write_image) #保存测试结果
print('step {:d}'.format(step))
if __name__ == '__main__':
main()