Tensrfow GAN Discriminator 如何使用hinge loss训练

1 篇文章 0 订阅
1 篇文章 0 订阅

hinge loss

核心点:网络的输出要确保是[-1,1]范围
之前一直用cross entrype loss这一点没有台注意,所以之前一直没写对!

  • hinge loss 核心代码
    def Hinge_loss(pos, neg, name='Hinge_loss'):
        with tf.variable_scope(name):
            d_loss = tf.reduce_mean(tf.nn.relu(1.0 - pos)) + tf.reduce_mean(tf.nn.relu(1.0 + neg))
            g_loss = -tf.reduce_mean(neg)
        return g_loss, d_loss

pos是正样本,neg是负样本,d_loss返回的loss适用于更新Discriminator网络,g_loss用于Generator网络。
最理想的情况如下:

–对于D net:pos里面的值全部为1,neg里面的值全部为-1

–对于G net:neg里面的值全部为1

总结: 之前的网络输出是sigmoid导致g_loss一直是个负值,训练一直出错。

  • 网络的输
# x = tf.nn.sigmoid(x)
# solution 1
x = leaky_relu(x)
x = tf.clip_by_value(x, -1., 1.)
# solution 2
x = tf.nn.tanh(x)

– 将Dnet的最后一层输出由sigmoid转变成leak_relu + clip的形式

–也可以换成tanh,缺点是缺少-1和1两个值

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值