现有的ID-CGAN暂时没有找到tf版本,那就自己来手动复现一下。
首先读一下论文中整体的网络结构:生成网络由几个密集模块组合而成,中间加了很多下采样层,最后加了tanh激活函数。
原文中对生成网络结构的描述如下:
密集模块由3×3的卷积核组成,中间使用了skip concatenation(跳跃拼接),通道数的变化也写明了,那就直接上代码。
def convB(self, x, ch, i, num=None, is_training=True):
    """One conv -> batch-norm -> leaky-ReLU unit used inside the dense blocks.

    NOTE(review): dense_idcgan calls this as convB(x, ch//2, num, i), i.e.
    the block index lands in `num` and the (constant) unit count in `i`.
    The generated variable names stay unique, so behavior is unaffected,
    but the two index arguments look swapped -- confirm intent.
    """
    with tf.variable_scope('block_{}'.format(num)):
        out = conv2d(input_=x, output_dim=ch, kernel_size=3, stride=1,
                     name="conv2d_RRDB_{}_{}".format(num, i))
        out = batch_norm(out, is_training=is_training,
                         name='_conv2d_RRDB_{}_{}'.format(num, i))
        return lrelu(out)
def dense_idcgan(self, x, ch, num, count):
    """Densely-connected sub-block: `num` convB units whose outputs (plus
    the block input) are all concatenated along the channel axis.

    NOTE(review): convB's signature is (x, ch, i, num) but it is invoked
    here as convB(x, ch//2, num, i) -- the two index arguments appear
    swapped. Variable names remain unique either way, so the call is
    kept byte-for-byte to preserve checkpoint-compatible scope names.
    """
    with tf.variable_scope('dense_sub{0}'.format(count)):
        features = [x]
        out = x
        for i in range(num):
            out = self.convB(out, ch // 2, num, i)
            features.append(out)
        return tf.concat(features, axis=-1)
def generator(self, image, gf_dim=64, reuse=False, name="generator"):
    """Encoder-decoder generator built from dense blocks.

    Downsamples three times (max-pool + two stride-2 convs), then
    upsamples back with transpose convs, concatenating skip features at
    half and full resolution, and maps the result to 3 channels through
    a final tanh (output range [-1, 1]).
    """
    with tf.variable_scope(name, reuse=reuse):
        # Stem conv at full resolution; both tensors are reused later
        # as skip connections in the decoder.
        skip_full = conv2d(input_=image, output_dim=gf_dim, kernel_size=3, stride=1, name='g_first_conv_0')
        skip_half = max_pool(skip_full, 2)

        # ----- encoder: dense block followed by a conv "transition" -----
        d1 = self.dense_idcgan(skip_half, gf_dim, 4, 1)
        d1_down = conv2d(input_=d1, output_dim=gf_dim * 2, kernel_size=3, stride=2, name='g_first_conv_1')
        d2 = self.dense_idcgan(d1_down, gf_dim * 2, 4, 2)
        d2_down = conv2d(input_=d2, output_dim=gf_dim * 4, kernel_size=3, stride=2, name='g_first_conv_2')
        d3 = self.dense_idcgan(d2_down, gf_dim * 4, 6, 3)
        # The next two transitions use stride 1: resolution is kept, only
        # the channel count changes.
        d3_proj = conv2d(input_=d3, output_dim=gf_dim * 8, kernel_size=3, stride=1, name='g_first_conv_3')
        d4 = self.dense_idcgan(d3_proj, gf_dim * 8, 6, 4)
        d4_proj = conv2d(input_=d4, output_dim=gf_dim * 2, kernel_size=3, stride=1, name='g_first_conv_4')
        d5 = self.dense_idcgan(d4_proj, gf_dim * 2, 4, 5)

        # ----- decoder: 2x transpose convs with skip concatenations -----
        u1 = tf.layers.conv2d_transpose(d5, 120, 3, 2, padding='same', name='g_d1')
        d6 = self.dense_idcgan(u1, 120, 4, 6)
        u2 = tf.layers.conv2d_transpose(d6, 64, 3, 2, padding='same', name='g_d2')
        d7 = self.dense_idcgan(tf.concat([u2, skip_half], -1), gf_dim, 3, 7)
        u3 = tf.layers.conv2d_transpose(d7, 64, 3, 2, padding='same', name='g_d3')
        d8 = self.dense_idcgan(tf.concat([u3, skip_full], -1), gf_dim, 4, 8)

        # 1x1/stride-1 transpose conv acts as a plain channel projection.
        u4 = tf.layers.conv2d_transpose(d8, 16, 1, 1, padding='same', name='g_d4')
        rgb = conv2d(input_=u4, output_dim=3, kernel_size=3, stride=1, name='g_first_conv_10')
        return tf.nn.tanh(rgb)
后面是判别器
判别网络基于VGG,在最后一个下采样层之后接入了金字塔池化(Pyramid Pooling)结构,按照论文的意思直接搭建模型就可以了。
def res_dense(self, x_init, ch, num, is_training=True):
    """Discriminator stage: two 4x4 conv -> BN -> leaky-ReLU layers, a
    1x1 transpose-conv channel projection, then a 2x2 max-pool."""
    with tf.variable_scope('res_dense{}'.format(num)):
        h = conv2d(input_=x_init, output_dim=ch, kernel_size=4, stride=1,
                   name='res_dense_conv_a_conv3_{}'.format(num))
        h = lrelu(batch_norm(h, is_training=is_training,
                             name='res_dense_conv1_b3_{}'.format(num)))
        h = conv2d(input_=h, output_dim=ch, kernel_size=4, stride=1,
                   name='res_dense_conv_b_conv3_{}'.format(num))
        h = lrelu(batch_norm(h, is_training=is_training,
                             name='res_dense_conv2_b3_{}'.format(num)))
        # 1x1, stride-1 transpose conv == per-pixel channel projection.
        h = tf.layers.conv2d_transpose(h, ch, 1, 1, padding='same',
                                       name='res_dense_conv3_{}'.format(num))
        return max_pool(h, 2)
def up_conv_pypool(self, x_init, feature_map, filter, num, is_training=True):
    """2x upsample (transpose conv) followed by conv -> BN -> leaky-ReLU.

    NOTE(review): `feature_map` is never read in this body, while every
    call site in Pyramid_Pool passes a channel count (ch//4, ch//8, ...)
    into `feature_map` and the pool factor (2/4/8) into `filter` -- the
    two arguments look swapped; confirm which was intended. (`filter`
    also shadows the builtin; the name is kept for interface stability.)
    """
    with tf.variable_scope('up_conv_pypool_{}'.format(num)):
        up = tf.layers.conv2d_transpose(x_init, filter, 2, 2, padding='same',
                                        name='g_d_{}'.format(num))
        conv = conv2d(input_=up, output_dim=filter, kernel_size=3, stride=1,
                      name='res_dense_conv1_{}'.format(num))
        return lrelu(batch_norm(conv, is_training=is_training,
                                name='res_dense_b_{}'.format(num)))
def Pyramid_Pool(self, x_init, ch):
    """Pyramid pooling: pool the input at scales 1/2, 1/4 and 1/8,
    upsample each branch back to the input resolution with chained
    up_conv_pypool calls, and concatenate all branches (plus the raw
    input) along the channel axis."""
    with tf.variable_scope('Pyramid_Pool'):
        branches = [x_init]

        # 1/2 scale: a single 2x upsample restores the resolution.
        half = self.up_conv_pypool(max_pool(x_init, 2), ch // 4, 2, 0)
        branches.append(half)

        # 1/4 scale: two chained 2x upsamples.
        quarter = self.up_conv_pypool(max_pool(x_init, 4), ch // 8, 4, 1)
        quarter = self.up_conv_pypool(quarter, ch // 4, 4, 2)
        branches.append(quarter)

        # 1/8 scale: three chained 2x upsamples.
        eighth = self.up_conv_pypool(max_pool(x_init, 8), ch // 16, 8, 3)
        eighth = self.up_conv_pypool(eighth, ch // 8, 8, 4)
        eighth = self.up_conv_pypool(eighth, ch // 4, 8, 5)
        # NOTE(review): in TF1 the second positional argument of
        # tf.nn.dropout is keep_prob, so 0.2 KEEPS only 20% of the units
        # -- and it is applied unconditionally, including at inference.
        # Confirm a 20% drop rate (keep_prob=0.8) was not intended.
        branches.append(tf.nn.dropout(eighth, 0.2))

        return tf.concat(branches, axis=-1)
def discriminator(self, image, targets, df_dim=64, reuse=False, name="discriminator", is_training=True):
    """Discriminator: concatenates the conditioning image and the target
    channel-wise, runs four res_dense downsampling stages, applies
    pyramid pooling, and squashes the feature map with a sigmoid."""
    with tf.variable_scope(name, reuse=reuse):
        pair = tf.concat([image, targets], 3)
        h = self.res_dense(pair, df_dim, 0, is_training=is_training)
        h = self.res_dense(h, df_dim * 4, 1, is_training=is_training)
        h = self.res_dense(h, df_dim * 8, 2, is_training=is_training)
        h = self.res_dense(h, df_dim, 3, is_training=is_training)
        # Input here is (1, 16, 16, 64); the paper's implementation
        # reports (1, 14, 14, 64) at this point.
        h = self.Pyramid_Pool(h, df_dim)
        return tf.nn.sigmoid(h)
原文的输入是每一个样本扔到网络里面
这个这里我改成了batch_size =4
数据集是从ImageNet中获取的,由于算力资源有限,只选取了4个种类作为训练样本,分别是汽车、羊、鸟类和斑马。ImageNet中每个种类的样本是1300张,所以每个种类用了1000张作为训练,300张作为测试,附上效果图。
效果图里面左边是噪声图样本,中间为预测样本,最右边是ground truth。可以看到恢复效果还是可以的,不过原图噪声样本中的某些纹理已经丢失了,所以没有办法很好地恢复,后期再改进。