图像恢复_ID-CGAN的tf版本复现

现有的ID-CGAN暂时没有找到tf版本,那就自己来手动复现一下。
首先读下论文中整体的网络结构生成网络由几个密集模块组合,中间加了很多下采样层,最后加了tanh激活函数

原文中对生成网络结构的描述如下:
密集模块由3×3的卷积核组成,中间使用了skip concatenation(跳跃拼接),通道数的变化也写明了,那就直接上代码。在这里插入图片描述
在这里插入图片描述

    def convB(self, x, ch, i, num=None, is_training=True):
        """Basic conv -> batch-norm -> leaky-ReLU unit used inside the dense sub-blocks.

        Args:
            x: input feature map tensor.
            ch: number of output channels for the 3x3 convolution.
            i: layer index, used only to build unique op/variable names.
            num: block index, used for the enclosing variable scope name.
            is_training: passed through to batch_norm (train vs. inference stats).

        Returns:
            The activated feature map with `ch` channels, same spatial size (stride 1).
        """
        scope_name = 'block_{}'.format(num)
        with tf.variable_scope(scope_name):
            conv_out = conv2d(input_=x, output_dim=ch, kernel_size=3, stride=1,
                              name="conv2d_RRDB_{}_{}".format(num, i))
            normed = batch_norm(conv_out, is_training=is_training,
                                name='_conv2d_RRDB_{}_{}'.format(num, i))
            activated = lrelu(normed)
        return activated


    def dense_idcgan(self,x,ch,num,count):
        # Dense sub-block: runs `num` conv units, concatenating every intermediate
        # output (DenseNet-style) along the channel axis, so the channel count grows
        # with each iteration.  `count` only names the variable scope.
        with tf.variable_scope('dense_sub{0}'.format(count)):
            layers_concat = [x]
            for i in range(num):
                # NOTE(review): arguments look swapped relative to
                # convB(self, x, ch, i, num=None): here the constant layer count
                # `num` lands in convB's `i` slot and the loop variable `i` in the
                # `num` slot.  Variable names stay unique per iteration so the graph
                # still builds, but scope/op naming is misleading — confirm intent
                # before "fixing", since renaming breaks checkpoint compatibility.
                x = self.convB(x,ch//2,num,i)
                layers_concat.append(x)
                # Concatenate all outputs so far along channels (dense connectivity).
                x =tf.concat(layers_concat,axis=-1)
        return x

    def generator(self, image, gf_dim=64, reuse=False, name="generator"):
        # ID-CGAN generator: encoder of dense blocks with strided-conv downsampling,
        # then transposed-conv upsampling with skip concatenations back to early
        # features, finished by a 3-channel conv and tanh (outputs in [-1, 1]).
        #
        # Args:
        #     image: input image tensor (NHWC); channel count not constrained here.
        #     gf_dim: base channel width for the generator.
        #     reuse: reuse variables in the scope (for multi-tower / eval graphs).
        #     name: variable scope name.
        # Returns:
        #     tanh-activated 3-channel image tensor.
        with tf.variable_scope(name, reuse=reuse):

            # Stem conv at full resolution; `x` is kept as a skip source for x7_concat.
            x = conv2d(input_=image, output_dim=gf_dim, kernel_size=3, stride=1, name='g_first_conv_0')
            x0_pool = max_pool(x, 2)

            # Encoder: dense block then strided conv (stride 2) to downsample.
            x1 = self.dense_idcgan(x0_pool,gf_dim,4,1)
            x1_pool = conv2d(input_=x1, output_dim=gf_dim*2, kernel_size=3, stride=2, name='g_first_conv_1')


            x2 = self.dense_idcgan(x1_pool,gf_dim*2,4,2)
            x2_pool = conv2d(input_=x2, output_dim=gf_dim * 4, kernel_size=3, stride=2, name='g_first_conv_2')


            # Deeper blocks use 6 layers; these two convs keep spatial size (stride 1).
            x3 = self.dense_idcgan(x2_pool,gf_dim*4,6,3)
            x3_pool = conv2d(input_=x3, output_dim=gf_dim * 8, kernel_size=3, stride=1, name='g_first_conv_3')


            x4 = self.dense_idcgan(x3_pool,gf_dim*8,6,4)
            x4_pool = conv2d(input_=x4, output_dim=gf_dim *2, kernel_size=3, stride=1, name='g_first_conv_4')

            x5 =  self.dense_idcgan(x4_pool,gf_dim*2,4,5)

            # Decoder: stride-2 transposed convs upsample back toward input size.
            x5_up = tf.layers.conv2d_transpose(x5, 120, 3,2,padding='same',name='g_d1')

            x6 = self.dense_idcgan(x5_up, 120, 4, 6)

            x6_up = tf.layers.conv2d_transpose(x6, 64, 3,2,padding='same',name='g_d2')
            # Skip connection from the pooled stem features.
            x6_concat = tf.concat([x6_up,x0_pool],-1)

            x7 = self.dense_idcgan(x6_concat, gf_dim , 3, 7)

            x7_up = tf.layers.conv2d_transpose(x7, 64, 3,2,padding='same',name='g_d3')
            # Skip connection from the full-resolution stem conv output.
            x7_concat = tf.concat([x7_up,x],-1)

            x8 = self.dense_idcgan(x7_concat, gf_dim, 4, 8)
            # 1x1/stride-1 transpose conv acts as a channel-reduction projection.
            x8_up = tf.layers.conv2d_transpose(x8, 16, 1,1,padding='same',name='g_d4')

            # Final 3-channel projection and tanh to produce the restored image.
            x10 = conv2d(input_=x8_up, output_dim=3, kernel_size=3, stride=1, name='g_first_conv_10')
            out = tf.nn.tanh(x10)
            return out

后面是判别器
判别网络基于vgg,在最后一个下采样层进入了PAN结构,按照论文的意思直接搭建模型就可以了
在这里插入图片描述
在这里插入图片描述

    def res_dense(self, x_init, ch, num, is_training=True):
        """Discriminator stage: two 4x4 conv-BN-lrelu layers, a 1x1 transposed-conv
        projection, then 2x max pooling (halves spatial resolution).

        Args:
            x_init: input feature map.
            ch: output channel count for every layer in this stage.
            num: stage index used for scope/op names.
            is_training: passed to batch_norm.

        Returns:
            Downsampled feature map with `ch` channels.
        """
        with tf.variable_scope('res_dense{}'.format(num)):
            h = conv2d(input_=x_init, output_dim=ch, kernel_size=4, stride=1,
                       name='res_dense_conv_a_conv3_{}'.format(num))
            h = batch_norm(h, is_training=is_training,
                           name='res_dense_conv1_b3_{}'.format(num))
            h = lrelu(h)
            h = conv2d(input_=h, output_dim=ch, kernel_size=4, stride=1,
                       name='res_dense_conv_b_conv3_{}'.format(num))
            h = batch_norm(h, is_training=is_training,
                           name='res_dense_conv2_b3_{}'.format(num))
            h = lrelu(h)
            # 1x1 stride-1 transpose conv == plain 1x1 conv here (channel mixing only).
            h = tf.layers.conv2d_transpose(h, ch, 1, 1, padding='same',
                                           name='res_dense_conv3_{}'.format(num))
            h = max_pool(h, 2)
        return h

    def up_conv_pypool(self, x_init, feature_map, filter, num, is_training=True):
        """2x upsample (stride-2 transposed conv) followed by conv-BN-lrelu.

        NOTE(review): `feature_map` is accepted but never used; `filter` alone
        sets the channel count of both layers.  Confirm whether the callers in
        Pyramid_Pool meant to pass channel counts via `feature_map` instead.

        Args:
            x_init: input feature map.
            feature_map: unused (kept for caller compatibility).
            filter: output channel count for both layers.
            num: index used for scope/op names.
            is_training: passed to batch_norm.

        Returns:
            Upsampled, activated feature map with `filter` channels.
        """
        with tf.variable_scope('up_conv_pypool_{}'.format(num)):
            upsampled = tf.layers.conv2d_transpose(x_init, filter, 2, 2, padding='same',
                                                   name='g_d_{}'.format(num))
            conv_out = conv2d(input_=upsampled, output_dim=filter, kernel_size=3,
                              stride=1, name='res_dense_conv1_{}'.format(num))
            normed = batch_norm(conv_out, is_training=is_training,
                                name='res_dense_b_{}'.format(num))
            activated = lrelu(normed)
        return activated


    def Pyramid_Pool(self,x_init,ch):
        # Pyramid pooling: pool the input at scales 2/4/8, upsample each branch back
        # to the input resolution with stacked up_conv_pypool units, and concatenate
        # all branches (plus the raw input) along channels.
        #
        # NOTE(review): the 2nd/3rd positional args to up_conv_pypool are
        # (feature_map, filter), and feature_map is unused inside it — so these
        # branches end up with 2, 4 and 8 output channels respectively, while the
        # ch//4, ch//8, ch//16 values are ignored.  Looks unintentional; confirm.
        with tf.variable_scope('Pyramid_Pool'):
            x_list = [x_init]
            # Scale-2 branch: one 2x upsample restores resolution.
            x1 = max_pool(x_init, 2)
            x1_up = self.up_conv_pypool(x1,ch//4,2,0)
            x_list.append(x1_up)
            # Scale-4 branch: two 2x upsamples.
            x2 = max_pool(x_init, 4)
            x2_up = self.up_conv_pypool(x2, ch//8, 4, 1)
            x2_up = self.up_conv_pypool(x2_up, ch//4, 4, 2)
            x_list.append(x2_up)
            # Scale-8 branch: three 2x upsamples.
            x3 = max_pool(x_init, 8)
            x3_up = self.up_conv_pypool(x3, ch//16, 8, 3)
            x3_up = self.up_conv_pypool(x3_up, ch//8, 8, 4)
            x3_up = self.up_conv_pypool(x3_up, ch//4, 8, 5)
            # NOTE(review): in TF1, tf.nn.dropout's 2nd arg is keep_prob — 0.2 keeps
            # only 20% of activations.  If a 20% drop *rate* was intended, this
            # should be 0.8.  Confirm against training results.
            x3 = tf.nn.dropout(x3_up, 0.2)
            x_list.append(x3)
            out = tf.concat(x_list,axis=-1)
        return out

    def discriminator(self, image, targets, df_dim=64, reuse=False, name="discriminator", is_training=True):
        """Conditional discriminator: channel-concat the input image with the target,
        run four res_dense downsampling stages, apply pyramid pooling, and squash
        with a sigmoid.

        Args:
            image: conditioning (noisy/rainy) input image.
            targets: real or generated image to judge.
            df_dim: base channel width.
            reuse: reuse variables in the scope.
            name: variable scope name.
            is_training: passed through to batch-norm layers.

        Returns:
            Sigmoid-activated feature map (per-region real/fake scores).
        """
        with tf.variable_scope(name, reuse=reuse):
            # Condition on the input image via channel concatenation.
            features = tf.concat([image, targets], 3)
            # Channel multipliers per stage: 1x, 4x, 8x, then back to 1x.
            for stage, mult in enumerate((1, 4, 8, 1)):
                features = self.res_dense(features, df_dim * mult, stage, is_training=is_training)
            # input_x = (1,16,16,64); the paper reports (1,14,14,64) here.
            pooled = self.Pyramid_Pool(features, df_dim)
            return tf.nn.sigmoid(pooled)

原文的输入是每一个样本扔到网络里面
这个这里我改成了batch_size =4
数据集是从ImageNet中获取的,由于算力资源有限,只选取了4个种类作为样本进行训练,分别是汽车、羊、鸟类和斑马。ImageNet中每个种类的样本是1300张,所以每个种类用了1000张作为训练,300张作为测试,附上效果图
效果图里面左边是噪声图样本,中间为预测样本,最右边是ground truth。可以看到恢复效果还是可以的,不过原图噪声样本中的某些纹理已经丢失了,所以没有办法很好地恢复,后期再改进
在这里插入图片描述

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值