SPGAN networks

Discriminator network (a hedged layer sketch follows):
(?, 64, 192, 3) -> [1] (?, 32, 96, 64) -> [2] (?, 16, 48, 128) -> [3] (?, 8, 24, 256) -> [4] (?, 8, 24, 512) -> [5] (?, 8, 24, 1)
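For reference, a minimal TF1-style sketch of layers that reproduce exactly these shapes. The real models.discriminator is not shown in this post, so the kernel size, the LeakyReLU activations and the omitted instance normalization are assumptions based on the usual CycleGAN/PatchGAN design:

import tensorflow as tf

def discriminator_sketch(img, scope, reuse=False):
    # img: (?, 64, 192, 3)
    with tf.variable_scope(scope + '_discriminator', reuse=reuse):
        h = tf.nn.leaky_relu(tf.layers.conv2d(img, 64, 4, strides=2, padding='same'))   # [1] (?, 32, 96, 64)
        h = tf.nn.leaky_relu(tf.layers.conv2d(h, 128, 4, strides=2, padding='same'))    # [2] (?, 16, 48, 128)
        h = tf.nn.leaky_relu(tf.layers.conv2d(h, 256, 4, strides=2, padding='same'))    # [3] (?, 8, 24, 256)
        h = tf.nn.leaky_relu(tf.layers.conv2d(h, 512, 4, strides=1, padding='same'))    # [4] (?, 8, 24, 512)
        return tf.layers.conv2d(h, 1, 4, strides=1, padding='same')                     # [5] (?, 8, 24, 1) per-patch scores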
Generator network (a hedged sketch of one residual block follows):
(?, 64, 192, 3) -> [1] (?, 64, 192, 64) -> [2] (?, 32, 96, 128) -> [3] (?, 16, 48, 256) -> [4] (r1-r9, nine residual blocks) (?, 16, 48, 256) -> [5] (?, 32, 96, 128) -> [6] (?, 70, 198, 64) (presumably 64x192 reflect-padded by 3 for the final 7x7 valid convolution) -> [7] (?, 64, 192, 3)
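The nine residual blocks keep the (?, 16, 48, 256) shape unchanged, so they can be stacked freely. A minimal sketch of one block, again with an assumed kernel size and without the normalization layers the real models.generator likely uses:

def residual_block_sketch(x, dim=256, name='res'):
    # x: (?, 16, 48, dim); the output has the same shape
    with tf.variable_scope(name):
        y = tf.nn.relu(tf.layers.conv2d(x, dim, 3, strides=1, padding='same'))
        y = tf.layers.conv2d(y, dim, 3, strides=1, padding='same')
        return x + y  # skip connection: shapes match, so the sum is valid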
Siamese (metric) network (the stride-2 shape rule is checked below):
(?, 64, 192, 3) -> (?, 32, 96, 64) -> (?, 16, 48, 64) -> (?, 8, 24, 128) -> (?, 4, 12, 128) -> (?, 2, 6, 256) -> (?, 1, 3, 256) -> (?, 1, 2, 512) -> (?, 1024) -> (?, 128) -> (?, 64)
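All of the spatial sizes above follow the 'SAME'-padding rule out = ceil(in / stride) for the stride-2 layers, which can be checked without the model code:

def same_out(size, stride=2):
    # 'SAME' padding: output size = ceil(input size / stride)
    return (size + stride - 1) // stride

h, w = 64, 192
for _ in range(3):                     # e.g. the first three stride-2 layers of the discriminator
    h, w = same_out(h), same_out(w)
    print((h, w))                      # (32, 96) -> (16, 48) -> (8, 24)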

Code analysis:
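The code below is excerpted from the middle of the training script, so it relies on imports and hyper-parameters defined earlier. Roughly, the top of the script would look like this; the values are illustrative placeholders, except margin = 2.0 and the 64x192 crop size, which the post itself uses:

from glob import glob
import time

import numpy as np
import tensorflow as tf

import data    # project-local modules of the SPGAN code base
import models
import ops
import utils

gpu_id = 0
dataset = 'duke2market'                      # illustrative dataset name
crop_sizeh, crop_sizew = 64, 192             # matches the (?, 64, 192, 3) shapes above
load_size, batch_size, epoch = 286, 1, 200   # illustrative
lr = 0.0002                                  # illustrative learning rate
lambda1, lambda2, lambda3 = 10.0, 5.0, 2.0   # illustrative loss weights (cycle / identity / contrastive)
margin = 2.0                                 # contrastive margin C used below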

with tf.device('/gpu:%d' % gpu_id):
    ''' graph '''
    # --nodes--#
    a_real = tf.placeholder(tf.float32, shape=[None, crop_sizeh, crop_sizew, 3])  # real a
    b_real = tf.placeholder(tf.float32, shape=[None, crop_sizeh, crop_sizew, 3])  # real b
    a2b_sample = tf.placeholder(tf.float32, shape=[None, crop_sizeh, crop_sizew, 3])  # fake b (a->b) drawn from the history pool
    b2a_sample = tf.placeholder(tf.float32, shape=[None, crop_sizeh, crop_sizew, 3])  # fake a (b->a) drawn from the history pool

    a2b = models.generator(a_real, 'a2b')  # real a -> fake b; 'a2b' is the generator G: A -> B   (?, 64, 192, 3)
    b2a = models.generator(b_real, 'b2a')  # real b -> fake a; 'b2a' is the generator F: B -> A   (?, 64, 192, 3)
    b2a2b = models.generator(b2a, 'a2b', reuse=True)  # real b -> fake a -> reconstructed b   (?, 64, 192, 3)
    a2b2a = models.generator(a2b, 'b2a', reuse=True)  # real a -> fake b -> reconstructed a   (?, 64, 192, 3)

    b2b = models.generator(b_real, 'a2b', reuse=True)  # real b through G, used for the identity loss   (?, 64, 192, 3)
    a2a = models.generator(a_real, 'b2a', reuse=True)  # real a through F, used for the identity loss   (?, 64, 192, 3)

    a_dis = models.discriminator(a_real, 'a')  # discriminator D_a on real a   (?, 8, 24, 1)
    b2a_dis = models.discriminator(b2a, 'a', reuse=True)  # D_a on fake a   (?, 8, 24, 1)
    b2a_sample_dis = models.discriminator(b2a_sample, 'a', reuse=True)  # D_a on fake a drawn from the history pool (used only in the D_a loss)   (?, 8, 24, 1)
    # reuse=True reuses the existing variables of the scope instead of creating new ones
    b_dis = models.discriminator(b_real, 'b')  # discriminator D_b on real b   (?, 8, 24, 1)
    a2b_dis = models.discriminator(a2b, 'b', reuse=True)  # D_b on fake b   (?, 8, 24, 1)
    a2b_sample_dis = models.discriminator(a2b_sample, 'b', reuse=True)  # D_b on fake b drawn from the history pool (used only in the D_b loss)  (?, 8, 24, 1)

    # The lines above wire up the two generators (a->b and b->a) and the two discriminators.
    # --siamese network--#
    a_metric = tf.nn.l2_normalize(models.metric_net(a_real, 'metric'), 1)  # embedding of real a; tf.nn.l2_normalize scales each row to unit L2 norm
    a2b_metric = tf.nn.l2_normalize(models.metric_net(a2b, 'metric', reuse=True), 1)  # embedding of fake b (a->b)   shape=(?, 64)

    b_metric = tf.nn.l2_normalize(models.metric_net(b_real, 'metric', reuse=True), 1)  # embedding of real b
    b2a_metric = tf.nn.l2_normalize(models.metric_net(b2a, 'metric', reuse=True), 1)  # embedding of fake a (b->a)

    # --Positive Pair--  Eq. (5), i = 1: an image and its own translation should stay close
    C = tf.constant(margin, name="C")  # margin m = 2.0
    # real a vs. its translation a->b
    S_eucd_pos = tf.pow(tf.subtract(a_metric, a2b_metric), 2)  # elementwise squared difference of the two embeddings   (?, 64)
    S_metric_POS = tf.reduce_sum(S_eucd_pos, 1)  # sum over the embedding dimension -> squared Euclidean distance per sample
    # real b vs. its translation b->a
    T_eucd_pos = tf.pow(tf.subtract(b_metric, b2a_metric), 2)
    T_metric_POS = tf.reduce_sum(T_eucd_pos, 1)

    # --Negative Pair--  Eq. (5), i = 0, with margin m = 2.0 (the constant C above): images from different domains should stay at least m apart
    # real a vs. real b
    neg = tf.pow(tf.subtract(a_metric, b_metric), 2)  # elementwise squared difference between the a and b embeddings
    neg = tf.reduce_sum(neg, 1)
    neg = tf.sqrt(neg + 1e-6)  # Euclidean distance (the 1e-6 avoids the undefined gradient of sqrt at 0)
    NEG = tf.pow(tf.maximum(tf.subtract(C, neg), 0), 2)  # hinge: max(C - neg, 0) squared, so pairs already farther apart than C contribute nothing

    # --contrastive loss--#
    m_loss = tf.identity((T_metric_POS + S_metric_POS + 2 * NEG) / 3.0, name='metric_losses')  # average of the two positive terms and the double-weighted negative term
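    # Written out, m_loss is the contrastive loss of Eq. (5): a positive pair (i = 1) is penalized
    # by its squared Euclidean distance d^2, a negative pair (i = 0) by max(m - d, 0)^2 with
    # margin m = C = 2.0. A small NumPy mirror of the same computation, handy for sanity-checking
    # embeddings outside the graph (not part of the original script):
    def contrastive_loss_np(emb_a, emb_a2b, emb_b, emb_b2a, m=2.0):
        pos_s = np.sum((emb_a - emb_a2b) ** 2, axis=1)                 # d(a, a2b)^2
        pos_t = np.sum((emb_b - emb_b2a) ** 2, axis=1)                 # d(b, b2a)^2
        d_neg = np.sqrt(np.sum((emb_a - emb_b) ** 2, axis=1) + 1e-6)   # d(a, b)
        neg = np.maximum(m - d_neg, 0.0) ** 2                          # hinge on the negative pair
        return (pos_s + pos_t + 2 * neg) / 3.0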

    # --losses--#
    g_loss_a2b = tf.identity(ops.l2_loss(a2b_dis, tf.ones_like(a2b_dis)), name='g_loss_a2b')  # generator side of the adversarial loss: D_b's score on fake b is pushed toward 1
    g_loss_b2a = tf.identity(ops.l2_loss(b2a_dis, tf.ones_like(b2a_dis)), name='g_loss_b2a')
    g_orig = g_loss_a2b + g_loss_b2a  # adversarial part of the generator loss

    cyc_loss_a = tf.identity(ops.l1_loss(a_real, a2b2a) * lambda1, name='cyc_loss_a')  # cycle consistency: a -> b -> a should reconstruct a (L1, weighted by lambda1)
    cyc_loss_b = tf.identity(ops.l1_loss(b_real, b2a2b) * lambda1, name='cyc_loss_b')
    cyc_loss = cyc_loss_a + cyc_loss_b

    # --identity loss--#  feeding a real b through G (and a real a through F) should return the image unchanged
    idt_losss_b = tf.identity(ops.l1_loss(b2b, b_real) * lambda2, name='idt_loss_b')
    idt_losss_a = tf.identity(ops.l1_loss(a2a, a_real) * lambda2, name='idt_loss_a')
    idt_loss = idt_losss_b + idt_losss_a

    g_loss = g_loss_a2b + g_loss_b2a + cyc_loss + idt_loss + lambda3 * m_loss  # Eq. (6): adversarial + cycle + identity + lambda3 * contrastive terms

    d_loss_a_real = ops.l2_loss(a_dis, tf.ones_like(a_dis))
    d_loss_b2a_sample = ops.l2_loss(b2a_sample_dis, tf.zeros_like(b2a_sample_dis))
    d_loss_a = tf.identity((d_loss_a_real + d_loss_b2a_sample) / 2.0, name='d_loss_a')  # loss of discriminator D_a: real a toward 1, pooled fake a toward 0

    d_loss_b_real = ops.l2_loss(b_dis, tf.ones_like(b_dis))
    d_loss_a2b_sample = ops.l2_loss(a2b_sample_dis, tf.zeros_like(a2b_sample_dis))
    d_loss_b = tf.identity((d_loss_b_real + d_loss_a2b_sample) / 2.0, name='d_loss_b')  # loss of discriminator D_b: real b toward 1, pooled fake b toward 0
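    # ops.l2_loss / ops.l1_loss are project-local helpers; presumably ops.l2_loss(x, t) is
    # mean((x - t)^2) and ops.l1_loss(x, t) is mean(|x - t|), which makes the adversarial terms
    # above the least-squares GAN (LSGAN) objective: D pushes real patches toward 1 and pooled
    # fake patches toward 0, while G pushes its own fakes toward 1.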

    # --summaries--#  the individual loss terms logged to TensorBoard
    g_summary = ops.summary_tensors([g_loss_a2b, g_loss_b2a, cyc_loss_a, cyc_loss_b, idt_losss_a, idt_losss_b, m_loss])
    d_summary_a = ops.summary(d_loss_a)
    d_summary_b = ops.summary(d_loss_b)
    metric_summary = ops.summary(m_loss)

    ''' optim '''
    t_var = tf.trainable_variables()
    d_a_var = [var for var in t_var if 'a_discriminator' in var.name]
    d_b_var = [var for var in t_var if 'b_discriminator' in var.name]
    g_var = [var for var in t_var if 'a2b_generator' in var.name or 'b2a_generator' in var.name]
    metric_var = [var for var in t_var if 'metric_discriminator' in var.name]

    # four separate Adam optimizers, one per parameter group (D_a, D_b, the two generators, the metric net)
    d_a_train_op = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(d_loss_a, var_list=d_a_var)
    d_b_train_op = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(d_loss_b, var_list=d_b_var)
    g_train_op = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(g_loss, var_list=g_var)
    metric_train_op = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(m_loss, var_list=metric_var)

""" train """
''' init '''
# --session--#
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
tfconfig = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
tfconfig.gpu_options.allow_growth = True
sess = tf.Session(config=tfconfig)  # pass the config that actually carries the GPU options
# --counter--#
it_cnt, update_cnt = ops.counter()
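# ops.counter() is a project-local helper not shown in this post. Judging from how it is used
# below (sess.run(it_cnt) reads the persisted iteration number, sess.run(update_cnt) increments
# it once per loop iteration), it presumably wraps a non-trainable integer variable plus a
# tf.assign_add(cnt, 1) op, so that training can resume from a checkpointed count.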

'''data'''
a_img_paths = glob('./Datasets/' + dataset + '/bounding_box_train-Duke/*.jpg')
b_img_paths = glob('./Datasets/' + dataset + '/bounding_box_train-Market/*.jpg')
a_data_pool = data.ImageData(sess, a_img_paths, batch_size, load_size=load_size, crop_size=crop_sizeh)  # domain A (Duke) training images
b_data_pool = data.ImageData(sess, b_img_paths, batch_size, load_size=load_size, crop_size=crop_sizeh)  # domain B (Market) training images

a_test_img_paths = glob('./Datasets/' + dataset + '/bounding_box_train-Duke/*.jpg')
b_test_img_paths = glob('./Datasets/' + dataset + '/bounding_box_train-Market/*.jpg')
a_test_pool = data.ImageData(sess, a_test_img_paths, batch_size, load_size=load_size, crop_size=crop_sizeh)
b_test_pool = data.ImageData(sess, b_test_img_paths, batch_size, load_size=load_size, crop_size=crop_sizeh)

a2b_pool = utils.ItemPool()
b2a_pool = utils.ItemPool()
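# utils.ItemPool is the usual CycleGAN-style image history buffer: it stores previously generated
# images and, when called, returns a mix of fresh and stored ones, which keeps the discriminators
# from overfitting to the latest generator output. Its source is not shown here; a hypothetical
# sketch of the behavior the training loop relies on (the pool size and 50/50 swap rule are assumptions):
class ItemPoolSketch(object):
    def __init__(self, max_items=50):
        self.max_items = max_items
        self.items = []

    def __call__(self, new_items):
        out = []
        for item in new_items:
            if len(self.items) < self.max_items:   # pool not yet full: store and return the new item
                self.items.append(item)
                out.append(item)
            elif np.random.rand() < 0.5:           # pool full: half the time swap with a stored item
                idx = np.random.randint(self.max_items)
                out.append(self.items[idx])
                self.items[idx] = item
            else:                                  # otherwise pass the new item straight through
                out.append(item)
        return out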

'''summary'''
summary_writer = tf.summary.FileWriter('./summaries/' + dataset + '_spgan', sess.graph)

'''saver'''
ckpt_dir = './checkpoints/' + dataset + '_spgan'
utils.mkdir(ckpt_dir + '/')

saver = tf.train.Saver(max_to_keep=10000)
ckpt_path = utils.load_checkpoint(ckpt_dir, sess, saver)
if ckpt_path is None:
    sess.run(tf.global_variables_initializer())
else:
    print('Copy variables from %s' % ckpt_path)

'''train'''
try:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    batch_epoch = min(len(a_data_pool), len(b_data_pool)) // batch_size  # batches per epoch, limited by the smaller dataset
    max_it = epoch * batch_epoch  # total iterations = epochs * batches per epoch
    now = time.strftime("%c")
    print('================ Beginning Training time (%s) ================\n' % now)
    for it in range(sess.run(it_cnt), max_it):
        np.random.seed(0)
        tf.set_random_seed(0)
        sess.run(update_cnt)
        # --prepare data--#
        a_real_ipt = a_data_pool.batch()
        b_real_ipt = b_data_pool.batch()
        a2b_opt, b2a_opt = sess.run([a2b, b2a], feed_dict={a_real: a_real_ipt, b_real: b_real_ipt})
        a2b_sample_ipt = np.array(a2b_pool(list(a2b_opt)))
        b2a_sample_ipt = np.array(b2a_pool(list(b2a_opt)))

        # --train G--#
        g_summary_opt, _ = sess.run([g_summary, g_train_op], feed_dict={a_real: a_real_ipt, b_real: b_real_ipt})
        summary_writer.add_summary(g_summary_opt, it)
        # --train D_b--#
        d_summary_b_opt, _ = sess.run([d_summary_b, d_b_train_op],
                                      feed_dict={b_real: b_real_ipt, a2b_sample: a2b_sample_ipt})
        summary_writer.add_summary(d_summary_b_opt, it)
        # --train D_a--#
        d_summary_a_opt, _ = sess.run([d_summary_a, d_a_train_op],
                                      feed_dict={a_real: a_real_ipt, b2a_sample: b2a_sample_ipt})
        summary_writer.add_summary(d_summary_a_opt, it)
        # --train metric--#
        metric_summary_opt, _ = sess.run([metric_summary, metric_train_op],
                                         feed_dict={a_real: a_real_ipt, b_real: b_real_ipt, a2b: a2b_opt, b2a: b2a_opt})
        summary_writer.add_summary(metric_summary_opt, it)
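        # The excerpt ends here, inside the try: block. A hedged completion following the standard
        # TF1 Coordinator pattern (the checkpoint interval below is an assumption):
        if (it + 1) % 1000 == 0:
            save_path = saver.save(sess, '%s/it_%d.ckpt' % (ckpt_dir, it + 1))
            print('Model saved to %s' % save_path)

except Exception as e:
    coord.request_stop(e)      # record the error; coord.join re-raises it after the threads stop
finally:
    coord.request_stop()       # stop and join the input pipeline threads, then release the session
    coord.join(threads)
    sess.close()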
