How to override a CGAN model's fit method in TensorFlow 2 and draw the model's op graph

Yesterday I spent an entire day looking for a suitable way to visualize neural-network models. For a survey of the available tools, see ashishpatel26/Tools-to-Design-or-Visualize-Architecture-of-Neural-Network: Tools to Design or Visualize Architecture of Neural Network (github.com). Most of them, however, require you to re-describe the architecture by hand instead of consuming a model already defined in Keras, which makes the code very redundant. Comparing the options, Keras Visualization and TensorBoard fit the requirement best. Keras Visualization struggles to draw complex model graphs, and today's architectures are far more intricate than the classic CNN, LSTM, and MLP. TensorBoard is an excellent choice, but drawing the model graph directly in TF2 is also awkward: the cleanest route is to pass a callback to the model's fit() and let it trace the graph. That creates a new problem, though. Calling fit() directly only works for a narrow range of models; a GAN, for instance, cannot simply be trained with the stock fit() (I tried that approach before, the model was extremely hard to train, and I strongly advise against it for GANs). So how can we draw a GAN's training graph in TensorBoard? That is the question this post answers!
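For a plain supervised model, the callback route really is that simple. A minimal sketch of the standard workflow, using a throwaway Dense model and random data purely for illustration:

import tensorflow as tf
from datetime import datetime

model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='relu', input_shape=(8,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')

logdir = "logs/fit/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir, write_graph=True)

# one epoch on random data is enough for TensorBoard to trace and write the op graph
model.fit(tf.random.normal([32, 8]), tf.random.normal([32, 1]),
          epochs=1, callbacks=[tensorboard_callback])

The rest of this post is about making this same pattern available to a GAN, where the stock fit() does not apply.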

The key is how to build a fit() that works for GANs.

In other words: override fit, following Customizing what happens in Model.fit | TensorFlow Core (google.cn).

Now let's get hands-on.

You can follow this author, from whom I have learned a lot; a very impressive PhD with work at ICCV 2021:

LynnHo (Zhenliang He) (github.com)

First up is the generator of this image-to-image CGAN:

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_addons as tfa  # required below for InstanceNormalization


def _get_norm_layer(norm):
    # returns a layer constructor; call Norm() to get a fresh layer instance
    if norm == 'none':
        return lambda: lambda x: x  # identity "layer"
    elif norm == 'batch_norm':
        return keras.layers.BatchNormalization
    elif norm == 'instance_norm':
        return tfa.layers.InstanceNormalization
    elif norm == 'layer_norm':
        return keras.layers.LayerNormalization
    else:
        raise ValueError(f'unknown norm: {norm}')
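One detail worth spelling out: the factory returns a layer constructor, not a layer, so each Norm() call below builds a fresh layer with its own weights, while 'none' yields a lambda whose result is an identity function. For example:

Norm = _get_norm_layer('instance_norm')
n1 = Norm()  # a fresh InstanceNormalization layer
n2 = Norm()  # a second, independent one

identity = _get_norm_layer('none')()  # a no-op: identity(x) returns x unchanged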


def ResnetGenerator(input_shape=(227, 227, 3),
                    output_channels=3,
                    dim=64,
                    n_downsamplings=2,
                    n_blocks=9,
                    norm='instance_norm'):
    Norm = _get_norm_layer(norm)

    def _residual_block(x):
        dim = x.shape[-1]
        h = x

        h = tf.pad(h, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')
        h = keras.layers.Conv2D(dim, 3, padding='valid', use_bias=False)(h)
        h = Norm()(h)
        h = tf.nn.relu(h)

        h = tf.pad(h, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT')
        h = keras.layers.Conv2D(dim, 3, padding='valid', use_bias=False)(h)
        h = Norm()(h)

        return keras.layers.add([x, h])

    # 0: input
    h = inputs = keras.Input(shape=input_shape)

    # 1: 7x7 conv stem with reflection padding
    h = tf.pad(h, [[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')
    h = keras.layers.Conv2D(dim, 7, padding='valid', use_bias=False)(h)
    h = Norm()(h)
    h = tf.nn.relu(h)

    # 2: downsampling with stride-2 convs
    for _ in range(n_downsamplings):
        dim *= 2
        h = keras.layers.Conv2D(dim, 3, strides=2, padding='same', use_bias=False)(h)
        h = Norm()(h)
        h = tf.nn.relu(h)

    # 3: residual blocks
    for _ in range(n_blocks):
        h = _residual_block(h)

    # 4: upsampling with transposed convs
    for _ in range(n_downsamplings):
        dim //= 2
        h = keras.layers.Conv2DTranspose(dim, 3, strides=2, padding='same', use_bias=False)(h)
        h = Norm()(h)
        h = tf.nn.relu(h)

    # 5: project to output channels, tanh to [-1, 1]
    h = tf.pad(h, [[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT')
    h = keras.layers.Conv2D(output_channels, 7, padding='valid')(h)  # kernel 7 matches the 3-pixel reflection pad, preserving spatial size
    h = tf.tanh(h)

    return keras.Model(inputs=inputs, outputs=h)
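A quick shape sanity check is worth doing here. One caveat, my own observation rather than something from the code above: with the default 227x227 input, the two stride-2 downsamplings round 227 up to 114 and then 57, and the two transposed convolutions bring that back up to 228, so the output is one pixel larger than the input. A size divisible by 4, such as 256, round-trips exactly:

generator_check = ResnetGenerator(input_shape=(256, 256, 3))
out = generator_check(tf.random.normal([1, 256, 256, 3]))
print(out.shape)  # (1, 256, 256, 3): same spatial size in and out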

Then the discriminator:


def ConvDiscriminator(input_shape=(227, 227, 3),
                      dim=64,
                      n_downsamplings=3,
                      norm='instance_norm'):
    dim_ = dim
    Norm = _get_norm_layer(norm)

    # 0: input
    h = inputs = keras.Input(shape=input_shape)

    # 1: downsampling stack of stride-2 convs
    h = keras.layers.Conv2D(dim, 4, strides=2, padding='same')(h)
    h = tf.nn.leaky_relu(h, alpha=0.2)

    for _ in range(n_downsamplings - 1):
        dim = min(dim * 2, dim_ * 8)
        h = keras.layers.Conv2D(dim, 4, strides=2, padding='same', use_bias=False)(h)
        h = Norm()(h)
        h = tf.nn.leaky_relu(h, alpha=0.2)

    # 2: one more channel-doubling conv at stride 1
    dim = min(dim * 2, dim_ * 8)
    h = keras.layers.Conv2D(dim, 4, strides=1, padding='same', use_bias=False)(h)
    h = Norm()(h)
    h = tf.nn.leaky_relu(h, alpha=0.2)

    # 3: 1-channel output, a PatchGAN-style score map
    h = keras.layers.Conv2D(1, 4, strides=1, padding='same')(h)

    return keras.Model(inputs=inputs, outputs=h)
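It also helps to see what this discriminator returns: not one scalar per image but a PatchGAN-style score map, one real/fake score per receptive-field patch, which is why the losses below build their targets with ones_like/zeros_like rather than fixed-shape labels. A quick check, again with a 256x256 input:

discriminator_check = ConvDiscriminator(input_shape=(256, 256, 3))
scores = discriminator_check(tf.random.normal([1, 256, 256, 3]))
print(scores.shape)  # (1, 32, 32, 1): a grid of patch scores, not a single logit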

Model construction

generator = ResnetGenerator()
discriminator = ConvDiscriminator()

Set up the optimizers and loss functions


generator_optimizer = tf.keras.optimizers.Adam(2e-3, beta_1=0.5, decay=0.2)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-3, beta_1=0.5, decay=0.2)

mse = tf.keras.losses.MeanSquaredError()


def discriminator_loss(real_output, fake_output):
    real_loss = mse(tf.ones_like(real_output), real_output)
    fake_loss = mse(tf.zeros_like(fake_output), fake_output)

    return real_loss + fake_loss


def generator_loss(fake_output):
    fake_loss = mse(tf.ones_like(fake_output), fake_output)

    return fake_loss
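These are least-squares GAN (LSGAN) losses: the discriminator pushes real outputs toward 1 and fake outputs toward 0, while the generator pushes the discriminator's fake outputs toward 1. A tiny worked example with hand-picked scores:

real_output = tf.constant([[0.9], [0.8]])
fake_output = tf.constant([[0.1], [0.3]])
print(discriminator_loss(real_output, fake_output).numpy())  # 0.075: D is doing well
print(generator_loss(fake_output).numpy())                   # 0.65: G still has work to do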

Overriding the model's fit (the key part): subclass keras.Model and override train_step, which fit calls on every batch


class Image_CGAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim=None):
        super(Image_CGAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        super(Image_CGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def train_step(self, zip_image):

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:

            if isinstance(zip_image, tuple):
                # paired data: (real "positive" images, conditioning "negative" images)
                positive_image = zip_image[0]
                negative_image = zip_image[1]
                # batch_size = tf.shape(positive_image)[0]
            else:
                # noise-driven case, kept from the original but unused for image2image
                real_image = zip_image
                batch_size = tf.shape(real_image)[0]
                # random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

            generated_images = self.generator(negative_image, training=True)

            real_output = self.discriminator(positive_image, training=True)
            fake_output = self.discriminator(generated_images, training=True)

            gen_loss = self.g_loss_fn(fake_output)
            disc_loss = self.d_loss_fn(real_output, fake_output)

        # use the members stored by __init__/compile, not module-level globals
        gradients_of_generator = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)

        self.g_optimizer.apply_gradients(zip(gradients_of_generator, self.generator.trainable_variables))
        self.d_optimizer.apply_gradients(zip(gradients_of_discriminator, self.discriminator.trainable_variables))

        return {"d_loss": disc_loss, "g_loss": gen_loss}

Model compilation

# add the TensorBoard callback
from datetime import datetime  # needed for the timestamped log directory

logdir = "logs/fit/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)


# set up, compile, and train the model
gan = Image_CGAN(discriminator=discriminator, generator=generator)
gan.compile(d_optimizer=discriminator_optimizer, g_optimizer=generator_optimizer,
            d_loss_fn=discriminator_loss, g_loss_fn=generator_loss)
gan.fit(dataset, epochs=1, callbacks=[tensorboard_callback])
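One thing the snippet above leaves implicit is what dataset looks like: train_step expects each element to be a (positive_image, negative_image) tuple, matching the isinstance check. A minimal sketch of such a pipeline, where positive_images and negative_images are placeholder float32 arrays of shape (N, 227, 227, 3) already scaled to [-1, 1]:

# placeholder arrays; in practice load and preprocess your own image pairs
dataset = tf.data.Dataset.from_tensor_slices((positive_images, negative_images))
dataset = dataset.shuffle(buffer_size=256).batch(4, drop_remainder=True)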

Now you can train the GAN just as you would fit a CNN.

The op graph then appears in TensorBoard: launch it with tensorboard --logdir logs/fit and open the Graphs tab.

Worth mentioning: you can skip overriding fit and instead call model.predict() or model.evaluate() with the same callback; that also produces a graph, but it is the conceptual graph rather than the op graph: it only sketches the model's overall architecture and does not expose the data operations and the connections between them.

One more thing: PyTorch's TensorBoardX can do all of this as well, and its graphs tend to be clearer than TF2's, which comes down to how PyTorch models are written. If you can use PyTorch, all the better; but if TF2 is what left you puzzled by these problems, TF2 offers a solution too, and a decent one. I hope this helps.

That is about it. This is my first blog post, and sharing feels pretty good~

(To throw out an idea:) overriding fit does more than enable graph drawing; it gives you access to the many features that hang off fit, such as its callbacks. Once you can override fit, a lot of TF2's old constraints simply dissolve, so it matters. One more observation: CSDN has plenty of beginner-level TF2 tutorials, but they are very basic. That cuts both ways: they can easily mislead newcomers, who are left helpless the moment a problem falls outside the tutorial's assumptions. I hope future beginner tutorials leave more room for extension; deep learning is an open field, after all.
