Tensrfow GAN Discriminator 如何使用hinge loss训练

hinge loss

核心点:网络的输出要确保是[-1,1]范围
之前一直用cross entrype loss这一点没有台注意,所以之前一直没写对!

  • hinge loss 核心代码
    def Hinge_loss(pos, neg, name='Hinge_loss'):
        with tf.variable_scope(name):
            d_loss = tf.reduce_mean(tf.nn.relu(1.0 - pos)) + tf.reduce_mean(tf.nn.relu(1.0 + neg))
            g_loss = -tf.reduce_mean(neg)
        return g_loss, d_loss

pos是正样本,neg是负样本,d_loss返回的loss适用于更新Discriminator网络,g_loss用于Generator网络。
最理想的情况如下:

–对于D net:pos里面的值全部为1,neg里面的值全部为-1

–对于G net:neg里面的值全部为1

总结: 之前的网络输出是sigmoid导致g_loss一直是个负值,训练一直出错。

  • 网络的输
# x = tf.nn.sigmoid(x)
# solution 1
x = leaky_relu(x)
x = tf.clip_by_value(x, -1., 1.)
# solution 2
x = tf.nn.tanh(x)

– 将Dnet的最后一层输出由sigmoid转变成leak_relu + clip的形式

–也可以换成tanh,缺点是缺少-1和1两个值

发布了49 篇原创文章 · 获赞 14 · 访问量 19万+
展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 大白 设计师: CSDN官方博客

分享到微信朋友圈

×

扫一扫,手机浏览