DCGAN论文改进之处+简化代码

一、论文亮点

论文地址:https://arxiv.org/abs/1511.06434

论文第三章讲了改进点,如下:

  1. 将pooling层替换成带strides的卷积层。判别器中就是带strides的卷积,生成器中,论文中说是fractional-strided
  2. TF中用的conv2d_transpose,总之是上采样。
  3. 消除顶层卷积特征中的全连接层,为了实现更深的网络。顶层特征指的是生成器的输入,以及判别器的输出。
  4. 使用BatchNorm。直接对所有层使用batchnorm会导致震荡和不稳定。所以在生成器的输出层和辨别其的输入层不用。
  5. 在生成器中使用ReLU激活,除了输出层,用的是tanh激活。辨别器使用的是leakyReLU激活,尤其对于高分辨率建模。 

第四章讲了训练超参数细节:

  1. 对于输入的训练图,预处理缩放到tanh的范围[-1,1];
  2. 模型用SGD训练,batchsize为128;
  3. 权重用零中心正态分布、标准偏差0.02初始化;
  4. LeakyReLU的leak rate设为0.2;
  5. 如果使用Adam加速训练,推荐的0.001学习率太高,改用0.0002;
  6. 另外,momentum beta1用推荐的0.9导致振荡不稳定,降到0.5会稳定很多。

 在LSUN(Large-scale Scene Understanding (LSUN))场景下,使用的生成器网络结构如下:

PS:论文中总共在三个数据集上做了测试:Large-scale Scene Understanding (LSUN),Imagenet-1k,a newly assembled Faces dataset。实验结果什么的我就不贴了。

二、简化代码解读

git上有份很好的demo,可以试试

https://github.com/carpedm20/DCGAN-tensorflow

参考这个demo的代码,我把DCGAN的关键的生成器、判别器、loss、train用最简单的代码写出来,先整理下思路,然后就放飞自我随便修改做各种尝试了,毕竟后面还有很多很多的CNN GAN网络,WGAN,WGAN-GP,LSGAN,等等等

判别器和生成器使用的函数:

其中的很多函数,新版的tf都是有的,看看就好

# 输出shape= [batch_size, output_size],可能也会输出w和b
def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
    shape = input_.get_shape().as_list()  # [batch, dim]

    with tf.variable_scope(scope or "Linear"):
        try:
            matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32,
                                     tf.random_normal_initializer(stddev=stddev))
            # matrix shape = [dim, output_size]
        except ValueError as err:
            msg = "NOTE: Usually, this is due to an issue with the image dimensions.  Did you correctly set '--crop' or '--input_height' or '--output_height'?"
            err.args = err.args + (msg,)
            raise
        bias = tf.get_variable("bias", [output_size],
                               initializer=tf.constant_initializer(bias_start))
        # bias shape = [output_size]
        # tf.matmul shape = [batch, output_size]
        if with_w:
            return tf.matmul(input_, matrix) + bias, matrix, bias
        else:
            return tf.matmul(input_, matrix) + bias

# kernel_size=5,strides = 2, padding='same',带bias的卷积
def conv2d(input_, output_dim,
           k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
           name="conv2d"):
    with tf.variable_scope(name):
        w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
                            initializer=tf.truncated_normal_initializer(stddev=stddev))
        conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')

        biases = tf.get_variable('biases', [output_dim],  initializer=tf.constant_initializer(0.0))
        conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())

        return conv


# kernel_size=5,strides = 2, padding='same',带bias的反卷积(转置卷积)
def deconv2d(input_, output_shape,
             k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
             name="deconv2d", with_w=False):
    with tf.variable_scope(name):
        # filter : [height, width, output_channels, in_channels]
        w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
                            initializer=tf.random_normal_initializer(stddev=stddev))

        try:
            deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,
                                            strides=[1, d_h, d_w, 1])

        # Support for verisons of TensorFlow before 0.7.0
        except AttributeError:
            deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape,
                                    strides=[1, d_h, d_w, 1])

        biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
        deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())

        if with_w:
            return deconv, w, biases
        else:
            return deconv



def conv_out_size_same(size, stride):
    return int(math.ceil(float(size) / float(stride)))


def bn(x,epsilon=1e-5, momentum=0.9, name="batch_norm"):
    tf.layers.batch_normalization(inputs=x,momentum=momentum,epsilon=epsilon,name=name)


def lrelu(x, leak=0.2, name="lrelu"):
    return tf.maximum(x, leak * x,name)


辨别器

# 假设输入的batch_size = 16,第一层输出通道df_dim = 64
# 输入的image的shape=[batch_size,96,96,3]
def discriminator(image, reuse=False):
    with tf.variable_scope("discriminator") as scope:
        # 辨别器可能会被使用多次,所以可能需要resue=True
        if reuse:
            scope.reuse_variables()

        # lrelu、bn、conv2d、linear的代码后面会有
        h0 = lrelu(conv2d(image, df_dim, name='d_h0_conv'))
        h1 = lrelu(bn(conv2d(h0, df_dim * 2, name='d_h1_conv')))
        h2 = lrelu(bn(conv2d(h1, df_dim * 4, name='d_h2_conv')))
        h3 = lrelu(bn(conv2d(h2, df_dim * 8, name='d_h3_conv')))

        # h3的shape=[batchsize,6,6,512]
        h4 = linear(tf.reshape(h3, [batch_size, -1]), 1, 'd_h4_lin')  # similar as full connect
        # h4 的shape=[batchsize,1]
        return tf.nn.sigmoid(h4), h4

生成器

# output_height, output_width都是96, gf_dim = 64
# 输入的z的shape=[batch_size,100]
def generator(z):
    with tf.variable_scope("generator") as scope:
        s_h, s_w = output_height, output_width  # [96,96]
        s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)  # [48,48]
        s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)  # [24,24]
        s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)  # [12,12]
        s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)  # [6,6]

        
        z_, h0_w, h0_b = linear(z, gf_dim * 8 * s_h16 * s_w16, 'g_h0_lin', with_w=True)
        # z_的shape=[batch_size,512*6*6]

        h0 = tf.reshape(z_, [-1, s_h16, s_w16, gf_dim * 8])
        h0 = tf.nn.relu(bn(h0))

        h1 = deconv2d(h0, [batch_size, s_h8, s_w8, gf_dim * 4], name='g_h1', with_w=True)
        h1 = tf.nn.relu(bn(h1))

        h2 = deconv2d(h1, [batch_size, s_h4, s_w4, gf_dim * 2], name='g_h2', with_w=True)
        h2 = tf.nn.relu(bn(h2))

        h3 = deconv2d(h2, [batch_size, s_h2, s_w2, gf_dim * 1], name='g_h3', with_w=True)
        h3 = tf.nn.relu(bn(h3))

        h4 = deconv2d(h3, [batch_size, s_h, s_w, c_dim], name='g_h4', with_w=True)

        # h4的shape=[batch_size, 96, 96, 3]
        return tf.nn.tanh(h4)

loss

# 生成的假图
G = self.generator(z)
# 真图判别结果
D, D_logits = discriminator(inputs, reuse=False)
# 假图判别结果
D_, D_logits_ = discriminator(G, reuse=True)

# 判别器loss: 真图-真
d_loss_real = tf.reduce_mean(
            sigmoid_cross_entropy_with_logits(D_logits, tf.ones_like(D)))
# 判别器loss: 假图-假
d_loss_fake = tf.reduce_mean(
            sigmoid_cross_entropy_with_logits(D_logits_, tf.zeros_like(D_)))
# 判别器总loss
d_loss =d_loss_real + d_loss_fake

# 生成器loss: 假图-真
g_loss = tf.reduce_mean(
            sigmoid_cross_entropy_with_logits(D_logits_, tf.ones_like(D_)))

optimizer

这里,生成器的optimizer只优化生成器里面的参数;同理,判别器只优化判别器参数

t_vars = tf.trainable_variables()
# 判别器参数变量
d_vars = [var for var in t_vars if 'd_' in var.name]
# 生成器参数变量
g_vars = [var for var in t_vars if 'g_' in var.name]
# 只优化判别器
d_optim = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(d_loss, var_list=d_vars)
#只优化生成器
g_optim = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(g_loss, var_list=g_vars)

train

with tf.Session() as sess:
    # 生成batch_size个z
    batch_z = np.random.uniform(-1, 1, [batch_size, z_dim]).astype(np.float32)
    # 当然还有读入batch_size个image,这里就不写了
    # 然后feed给sess
    sess.run(d_optim,feed_dict={inputs: batch_images,z: batch_z})  # 更新判别器
    sess.run(g_optim,feed_dict={z: batch_z})  # 更新生成器

  • 0
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 6
    评论
改进dcgan的刺绣图像修复研究代码主要包括以下几个模块:数据预处理、DCGAN模型构建、训练和评估。 首先,对刺绣图像数据进行预处理,包括数据的清洗、归一化和切分。清洗数据是为了去除噪声和无效信息,同时对数据进行归一化处理,将数据转化成模型可以接受的格式,最后对数据进行切分,划分成训练集和测试集。 其次,构建改进DCGAN模型,包括生成器和判别器两个部分。生成器负责生成缺失部分的图像,而判别器则负责鉴别生成的图像和真实的图像。相比原始的DCGAN改进的模型可能包括更多的层或者使用不同的激活函数和损失函数,以提高修复效果。 接着,进行模型的训练和优化。通过将切分好的训练集输入到DCGAN模型中进行训练,不断调整模型的参数和超参数,直到模型收敛并得到较好的修复效果。同时可以使用一些优化算法如学习率衰减等来提高训练效果。 最后,对模型进行评估和测试。使用测试集对训练好的模型进行测试,评估修复效果,并根据评估结果对模型进行进一步的调优和改进。同时可以进行定量分析和定性分析,对修复效果进行综合评价。 综上所述,基于改进dcgan的刺绣图像修复的研究代码主要包括数据预处理、DCGAN模型构建、训练和优化以及评估和测试这几个主要模块,通过这些模块的合理设计和实现,可以得到较好的刺绣图像修复效果。
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值