Discriminator network:
(?, 64, 192, 3) ->1 (?, 32, 96, 64) ->2 (?, 16, 48, 128) ->3 (?, 8, 24, 256) ->4 (?, 8, 24, 512) ->5 (?, 8, 24, 1)
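The shape progression matches a PatchGAN-style discriminator: three stride-2 convolutions, a stride-1 convolution to 512 channels, and a final stride-1 convolution producing a 1-channel (?, 8, 24, 1) score map. A minimal sketch that reproduces these shapes; kernel sizes, activations and the lack of normalization are assumptions, not the repo's models.py:

import tensorflow as tf

def discriminator_sketch(img, scope, reuse=False):
    # scope name matches the 'a_discriminator' / 'b_discriminator' filters used later
    with tf.variable_scope(scope + '_discriminator', reuse=reuse):
        h = tf.nn.leaky_relu(tf.layers.conv2d(img, 64, 4, strides=2, padding='same'), 0.2)  # (?, 32, 96, 64)
        h = tf.nn.leaky_relu(tf.layers.conv2d(h, 128, 4, strides=2, padding='same'), 0.2)   # (?, 16, 48, 128)
        h = tf.nn.leaky_relu(tf.layers.conv2d(h, 256, 4, strides=2, padding='same'), 0.2)   # (?, 8, 24, 256)
        h = tf.nn.leaky_relu(tf.layers.conv2d(h, 512, 4, strides=1, padding='same'), 0.2)   # (?, 8, 24, 512)
        return tf.layers.conv2d(h, 1, 4, strides=1, padding='same')                         # (?, 8, 24, 1)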
Generator network:
(?, 64, 192, 3) ->1 (?, 64, 192, 64) ->2 (?, 32, 96, 128) ->3 (?, 16, 48, 256) ->4 (r1..r9, nine residual blocks) (?, 16, 48, 256) ->5 (?, 32, 96, 128) ->6 (?, 70, 198, 64) ->7 (?, 64, 192, 3)
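The generator follows the CycleGAN layout: a 7x7 stride-1 convolution, two stride-2 downsampling convolutions, nine residual blocks, two stride-2 upsampling layers, then reflection padding and a final 7x7 'valid' convolution back to 3 channels, which is what produces the (?, 70, 198, 64) intermediate at step 6. A simplified sketch under those assumptions (normalization omitted, kernel sizes guessed):

import tensorflow as tf

def res_block(x, dim):
    # two 3x3 convolutions with a skip connection; channel count is unchanged
    y = tf.layers.conv2d(x, dim, 3, padding='same', activation=tf.nn.relu)
    y = tf.layers.conv2d(y, dim, 3, padding='same')
    return x + y

def generator_sketch(img, scope, reuse=False):
    # scope name matches the 'a2b_generator' / 'b2a_generator' filters used later
    with tf.variable_scope(scope + '_generator', reuse=reuse):
        h = tf.layers.conv2d(img, 64, 7, strides=1, padding='same', activation=tf.nn.relu)           # 1: (?, 64, 192, 64)
        h = tf.layers.conv2d(h, 128, 3, strides=2, padding='same', activation=tf.nn.relu)            # 2: (?, 32, 96, 128)
        h = tf.layers.conv2d(h, 256, 3, strides=2, padding='same', activation=tf.nn.relu)            # 3: (?, 16, 48, 256)
        for _ in range(9):                                                                            # 4: r1..r9
            h = res_block(h, 256)                                                                     #    (?, 16, 48, 256)
        h = tf.layers.conv2d_transpose(h, 128, 3, strides=2, padding='same', activation=tf.nn.relu)  # 5: (?, 32, 96, 128)
        h = tf.layers.conv2d_transpose(h, 64, 3, strides=2, padding='same', activation=tf.nn.relu)   #    (?, 64, 192, 64)
        h = tf.pad(h, [[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')                              # 6: (?, 70, 198, 64)
        return tf.nn.tanh(tf.layers.conv2d(h, 3, 7, strides=1, padding='valid'))                     # 7: (?, 64, 192, 3)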
Siamese network:
(?, 64, 192, 3) -> (?, 32, 96, 64) -> (?, 16, 48, 64) -> (?, 8, 24, 128) -> (?, 4, 12, 128) -> (?, 2, 6, 256) -> (?, 1, 3, 256) -> (?, 1, 2, 512) -> (?, 1024) -> (?, 128) -> (?, 64)
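These shapes correspond to six stride-2 convolutions, a (1, 2) 'valid' convolution to 512 channels, a flatten to 1024, and two fully connected layers down to the 64-d embedding that is L2-normalized in the code below. A rough sketch, assuming leaky ReLU activations and the 'metric_discriminator' variable scope used by the training code:

import tensorflow as tf

def metric_net_sketch(img, scope, reuse=False):
    # scope name matches the 'metric_discriminator' filter used later
    with tf.variable_scope(scope + '_discriminator', reuse=reuse):
        h = img
        for dim in [64, 64, 128, 128, 256, 256]:  # six stride-2 convolutions
            h = tf.nn.leaky_relu(tf.layers.conv2d(h, dim, 3, strides=2, padding='same'), 0.2)
        # (?, 1, 3, 256) -> (?, 1, 2, 512) via a (1, 2) 'valid' convolution
        h = tf.nn.leaky_relu(tf.layers.conv2d(h, 512, (1, 2), padding='valid'), 0.2)
        h = tf.layers.flatten(h)      # (?, 1024)
        h = tf.layers.dense(h, 128)   # (?, 128)
        return tf.layers.dense(h, 64) # (?, 64)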
Code analysis:
with tf.device('/gpu:%d' % gpu_id):
''' graph '''
# --nodes--#
a_real = tf.placeholder(tf.float32, shape=[None, crop_sizeh, crop_sizew, 3])  # real a
b_real = tf.placeholder(tf.float32, shape=[None, crop_sizeh, crop_sizew, 3])  # real b
a2b_sample = tf.placeholder(tf.float32, shape=[None, crop_sizeh, crop_sizew, 3])  # fake b generated from a, fed from the history pool
b2a_sample = tf.placeholder(tf.float32, shape=[None, crop_sizeh, crop_sizew, 3])  # fake a generated from b, fed from the history pool
a2b = models.generator(a_real, 'a2b')  # real a -> fake b; 'a2b' is the generator G: a -> b  (?, 64, 192, 3)
b2a = models.generator(b_real, 'b2a')  # real b -> fake a; 'b2a' is the generator F: b -> a  (?, 64, 192, 3)
b2a2b = models.generator(b2a, 'a2b', reuse=True)  # real b -> fake a -> reconstructed b  (?, 64, 192, 3)
a2b2a = models.generator(a2b, 'b2a', reuse=True)  # real a -> fake b -> reconstructed a  (?, 64, 192, 3)
b2b = models.generator(b_real, 'a2b', reuse=True)  # real b passed through G (identity mapping)  (?, 64, 192, 3)
a2a = models.generator(a_real, 'b2a', reuse=True)  # real a passed through F (identity mapping)  (?, 64, 192, 3)
a_dis = models.discriminator(a_real, 'a')  # discriminator a on real a  (?, 8, 24, 1)
b2a_dis = models.discriminator(b2a, 'a', reuse=True)  # discriminator a on fake a  (?, 8, 24, 1)
b2a_sample_dis = models.discriminator(b2a_sample, 'a', reuse=True)  # discriminator a on pooled fake a samples (history buffer)  (?, 8, 24, 1)
# reuse=True shares variables that already exist in the scope; otherwise new ones are created
b_dis = models.discriminator(b_real, 'b')  # discriminator b on real b  (?, 8, 24, 1)
a2b_dis = models.discriminator(a2b, 'b', reuse=True)  # discriminator b on fake b  (?, 8, 24, 1)
a2b_sample_dis = models.discriminator(a2b_sample, 'b', reuse=True)  # discriminator b on pooled fake b samples (history buffer)  (?, 8, 24, 1)
# the lines above build the a->b generator/discriminator pair and the b->a generator/discriminator pair
# --siamese network--#
a_metric = tf.nn.l2_normalize(models.metric_net(a_real, 'metric'), 1)  # embedding of real a; tf.nn.l2_normalize normalizes each row to unit L2 norm
a2b_metric = tf.nn.l2_normalize(models.metric_net(a2b, 'metric', reuse=True), 1)  # embedding of fake b, shape=(?, 64)
b_metric = tf.nn.l2_normalize(models.metric_net(b_real, 'metric', reuse=True), 1)  # embedding of real b
b2a_metric = tf.nn.l2_normalize(models.metric_net(b2a, 'metric', reuse=True), 1)  # embedding of fake a
# --Positive Pair-- # positive pairs, Eq. (5) with i = 1
C = tf.constant(margin, name="C")  # margin m = 2.0
# image a vs. the b generated from a
S_eucd_pos = tf.pow(tf.subtract(a_metric, a2b_metric), 2)  # element-wise squared difference, shape (?, 64)
S_metric_POS = tf.reduce_sum(S_eucd_pos, 1)  # sum over each row: squared Euclidean distance
# image b vs. the a generated from b
T_eucd_pos = tf.pow(tf.subtract(b_metric, b2a_metric), 2)
T_metric_POS = tf.reduce_sum(T_eucd_pos, 1)
# --Negative Pair-- # negative pair, Eq. (5) with i = 0 and m = 2.0 (i.e. C)
# image a vs. image b
neg = tf.pow(tf.subtract(a_metric, b_metric), 2)  # element-wise squared difference between the a and b embeddings
neg = tf.reduce_sum(neg, 1)
neg = tf.sqrt(neg + 1e-6)  # Euclidean distance; the epsilon keeps the sqrt numerically stable
NEG = tf.pow(tf.maximum(tf.subtract(C, neg), 0), 2)  # squared hinge max(C - neg, 0)^2: pairs already farther apart than the margin contribute nothing
# --contrastive loss--#
m_loss = tf.identity((T_metric_POS + S_metric_POS + 2 * NEG) / 3.0, name='metric_losses')  # two positive terms plus a double-weighted negative term, averaged
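# Taken together this is the contrastive loss of Eq. (5). Assuming the usual
# contrastive-loss form (a reconstruction from the code, not copied from the paper):
#   L(i, x1, x2) = i * d(x1, x2)^2 + (1 - i) * max(m - d(x1, x2), 0)^2
# where d is the Euclidean distance between the L2-normalized embeddings,
# i = 1 for the positive pairs (a, a2b) and (b, b2a), i = 0 for the negative
# pair (a, b), and m = margin = 2.0 (the constant C above).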
# --losses--#
g_loss_a2b = tf.identity(ops.l2_loss(a2b_dis, tf.ones_like(a2b_dis)), name='g_loss_a2b')  # push discriminator b's output on fake b toward 1
g_loss_b2a = tf.identity(ops.l2_loss(b2a_dis, tf.ones_like(b2a_dis)), name='g_loss_b2a')  # push discriminator a's output on fake a toward 1
g_orig = g_loss_a2b + g_loss_b2a  # adversarial part of the generator loss
cyc_loss_a = tf.identity(ops.l1_loss(a_real, a2b2a) * lambda1, name='cyc_loss_a')  # cycle consistency: a -> b -> a should reconstruct a
cyc_loss_b = tf.identity(ops.l1_loss(b_real, b2a2b) * lambda1, name='cyc_loss_b')  # cycle consistency: b -> a -> b should reconstruct b
cyc_loss = cyc_loss_a + cyc_loss_b
# --identity loss--# b passed through the a2b generator (b2b) compared with b itself; likewise a through the b2a generator (a2a)
idt_loss_b = tf.identity(ops.l1_loss(b2b, b_real) * lambda2, name='idt_loss_b')
idt_loss_a = tf.identity(ops.l1_loss(a2a, a_real) * lambda2, name='idt_loss_a')
idt_loss = idt_loss_b + idt_loss_a
g_loss = g_loss_a2b + g_loss_b2a + cyc_loss + idt_loss + lambda3 * m_loss  # full generator objective, Eq. (6)
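# Assembled generator objective (Eq. (6)), as built line by line above:
#   L_G = L_adv(G) + L_adv(F) + lambda1 * L_cyc + lambda2 * L_idt + lambda3 * L_con
# (cyc_loss and idt_loss already carry their lambda1 / lambda2 weights).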
d_loss_a_real = ops.l2_loss(a_dis, tf.ones_like(a_dis))
d_loss_b2a_sample = ops.l2_loss(b2a_sample_dis, tf.zeros_like(b2a_sample_dis))
d_loss_a = tf.identity((d_loss_a_real + d_loss_b2a_sample) / 2.0, name='d_loss_a')  # discriminator a loss: real a toward 1, pooled fake a toward 0
d_loss_b_real = ops.l2_loss(b_dis, tf.ones_like(b_dis))
d_loss_a2b_sample = ops.l2_loss(a2b_sample_dis, tf.zeros_like(a2b_sample_dis))
d_loss_b = tf.identity((d_loss_b_real + d_loss_a2b_sample) / 2.0, name='d_loss_b')  # discriminator b loss: real b toward 1, pooled fake b toward 0
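# Both discriminators use the LSGAN least-squares objective: real inputs are pushed
# toward 1 and pooled fake inputs toward 0. ops.l2_loss / ops.l1_loss are presumably
# mean-squared-error and mean-absolute-error helpers, roughly
# tf.reduce_mean(tf.squared_difference(x, y)) and tf.reduce_mean(tf.abs(x - y)).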
# --summaries--# summaries for the individual loss terms
g_summary = ops.summary_tensors([g_loss_a2b, g_loss_b2a, cyc_loss_a, cyc_loss_b, idt_loss_a, idt_loss_b, m_loss])
d_summary_a = ops.summary(d_loss_a)
d_summary_b = ops.summary(d_loss_b)
metric_summary = ops.summary(m_loss)
''' optim '''
t_var = tf.trainable_variables()
d_a_var = [var for var in t_var if 'a_discriminator' in var.name]
d_b_var = [var for var in t_var if 'b_discriminator' in var.name]
g_var = [var for var in t_var if 'a2b_generator' in var.name or 'b2a_generator' in var.name]
metric_var = [var for var in t_var if 'metric_discriminator' in var.name]
# four optimizers are used, each optimizing its own set of variables
d_a_train_op = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(d_loss_a, var_list=d_a_var)
d_b_train_op = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(d_loss_b, var_list=d_b_var)
g_train_op = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(g_loss, var_list=g_var)
metric_train_op = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(m_loss, var_list=metric_var)
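# Adam with beta1 = 0.5 is the usual GAN setting (as in DCGAN/CycleGAN); the training
# loop below alternates updates of G, D_b, D_a and the metric net within each iteration.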
""" train """
''' init '''
# --session--#
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
tfconfig = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
tfconfig.gpu_options.allow_growth = True
sess = tf.Session(config=tfconfig)  # use tfconfig so the GPU memory settings above actually take effect
# --counter--#
it_cnt, update_cnt = ops.counter()
'''data'''
a_img_paths = glob('./Datasets/' + dataset + '/bounding_box_train-Duke/*.jpg')
b_img_paths = glob('./Datasets/' + dataset + '/bounding_box_train-Market/*.jpg')
a_data_pool = data.ImageData(sess, a_img_paths, batch_size, load_size=load_size, crop_size=crop_sizeh)  # dataset A (Duke)
b_data_pool = data.ImageData(sess, b_img_paths, batch_size, load_size=load_size, crop_size=crop_sizeh)  # dataset B (Market)
a_test_img_paths = glob('./Datasets/' + dataset + '/bounding_box_train-Duke/*.jpg')
b_test_img_paths = glob('./Datasets/' + dataset + '/bounding_box_train-Market/*.jpg')
a_test_pool = data.ImageData(sess, a_test_img_paths, batch_size, load_size=load_size, crop_size=crop_sizeh)
b_test_pool = data.ImageData(sess, b_test_img_paths, batch_size, load_size=load_size, crop_size=crop_sizeh)
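# data.ImageData presumably wraps a TensorFlow input queue (it takes the session and is
# later driven by tf.train.start_queue_runners); .batch() pulls the next preprocessed mini-batch.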
a2b_pool = utils.ItemPool()
b2a_pool = utils.ItemPool()
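# utils.ItemPool is presumably an image history buffer in the spirit of CycleGAN: it stores
# previously generated fakes and mixes them with the newest ones, so the discriminators are
# also trained against older samples (this is what feeds the *_sample placeholders).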
'''summary'''
summary_writer = tf.summary.FileWriter('./summaries/' + dataset + '_spgan', sess.graph)
'''saver'''
ckpt_dir = './checkpoints/' + dataset + '_spgan'
utils.mkdir(ckpt_dir + '/')
saver = tf.train.Saver(max_to_keep=10000)
ckpt_path = utils.load_checkpoint(ckpt_dir, sess, saver)
if ckpt_path is None:
sess.run(tf.global_variables_initializer())
else:
print('Copying variables from %s' % ckpt_path)
'''train'''
try:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
batch_epoch = min(len(a_data_pool), len(b_data_pool)) // batch_size  # batches per epoch, limited by the smaller dataset
max_it = epoch * batch_epoch  # total iterations = epochs * batches per epoch
now = time.strftime("%c")
print('================ Beginning training (%s) ================\n' % now)
for it in range(sess.run(it_cnt), max_it):
np.random.seed(0)
tf.set_random_seed(0)
sess.run(update_cnt)
# --prepare data--#
a_real_ipt = a_data_pool.batch()
b_real_ipt = b_data_pool.batch()
a2b_opt, b2a_opt = sess.run([a2b, b2a], feed_dict={a_real: a_real_ipt, b_real: b_real_ipt})
a2b_sample_ipt = np.array(a2b_pool(list(a2b_opt)))
b2a_sample_ipt = np.array(b2a_pool(list(b2a_opt)))
# --train G--#
g_summary_opt, _ = sess.run([g_summary, g_train_op], feed_dict={a_real: a_real_ipt, b_real: b_real_ipt})
summary_writer.add_summary(g_summary_opt, it)
# --train D_b--#
d_summary_b_opt, _ = sess.run([d_summary_b, d_b_train_op],
feed_dict={b_real: b_real_ipt, a2b_sample: a2b_sample_ipt})
summary_writer.add_summary(d_summary_b_opt, it)
# --train D_a--#
d_summary_a_opt, _ = sess.run([d_summary_a, d_a_train_op],
feed_dict={a_real: a_real_ipt, b2a_sample: b2a_sample_ipt})
summary_writer.add_summary(d_summary_a_opt, it)
# --train metric--#
metric_summary_opt, _ = sess.run([metric_summary, metric_train_op],
feed_dict={a_real: a_real_ipt, b_real: b_real_ipt, a2b: a2b_opt, b2a: b2a_opt})
summary_writer.add_summary(metric_summary_opt, it)