tf.keras实现ESRGAN相对平均判别器RaGAN损失

本文介绍了如何在TensorFlow中定义自定义损失函数,特别是针对ESRGAN(增强超分辨率生成对抗网络)中的相对鉴别器损失。通过直接定义损失函数,作者展示了如何计算并实现鉴别器和生成器的损失,涉及到关键步骤包括计算平均值、sigmoid激活以及交叉熵的变形式。文章强调了理解损失函数计算维度的重要性,并提供了详细的代码示例。
摘要由CSDN通过智能技术生成

引言

tf.keras.losses中的损失函数不够用的时候,就需要我们自己来定义一部分损失函数,来达到我们的需求。

方法

有两种方法来定义我们的损失函数,第一种是直接定义,第二种是子类化tf.keras.losses.Loss。我们来介绍第一种。

损失函数

我们要实现ESRGAN中的这种raletivistic discirminator的损失函数。下面是原文:
在这里插入图片描述
首先看到,这种鉴别器分别对fake和real作为输入,输出结果。然后,以Dra(xr, xf)为例,首先求取fake的预测的一个batch的平均值,然后用real预测去减这个平均值,最后求一个sigmoid,作为Dra(xr, xf)。
得出这个结果之后,显然Discriminator的想法是让Dra(xr, xf)最大化,由于经过了sigmoid激活,等价于让Dra(xr, xf)接近于1。同时,让Dra(xf, xr)最小化,由于经过sigmoid激活,等价于让其接近于0。
观察这个形式,发现就是交叉熵的形式,可以计算出两个Dra的形式之后,调用keras的交叉熵进行计算,我们为了更加细化,不调用交叉熵函数。

    def discriminator_loss(real_output, fake_output):
        Ra_loss_rf = tf.math.sigmoid((real_output) - tf.math.reduce_mean(fake_output, axis = 0))
        Ra_loss_fr = tf.math.sigmoid((fake_output) - tf.math.reduce_mean(real_output, axis = 0))
        L_Ra_d = - tf.math.reduce_mean(tf.math.log(Ra_loss_rf)) - tf.math.reduce_mean(tf.math.log(1- Ra_loss_fr))
        return L_Ra_d
    
    def generator_adversarial_loss(real_output, fake_output):
        Ra_loss_rf = tf.math.sigmoid((real_output) - tf.math.reduce_mean(fake_output, axis = 0))
        Ra_loss_fr = tf.math.sigmoid((fake_output) - tf.math.reduce_mean(real_output, axis = 0))
        L_Ra_g = - tf.math.reduce_mean(tf.math.log(1 - Ra_loss_rf)) - tf.math.reduce_mean(tf.math.log(Ra_loss_fr))
        return L_Ra_g

就是这个形式了,大家和上面论文比对,很容易看出来是怎么回事。这里面有几个要点想和大家分享。

  1. tf.reduce_mean这个函数,注意在Ra_loss_rf计算的地方,我指定了axis=0,这里就是说要在batch的维度上求平均,论文中是这样要求的。如果不指定这个axis = 0,则tf.reduce_mean这个函数会把所有的元素全部加起来求平均,则与论文要求不符。
  2. 考虑损失函数的写法的时候,要从四维的角度去考虑(batch, height, width, channel)的角度去考虑,尤其是包含tf.math.reduce_mean函数的时候(注意,tf.reduce_meantf.math.reduce_mean是一样的。)
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值