自定义交叉熵损失计算出现nan值问题解决

41 篇文章 10 订阅
10 篇文章 0 订阅

交叉熵损失函数本身的公式比较简单,但是在实际定义的时候需要注意exp(x)函数的溢出问题,exp(x)函数在numpy或者说tensorflow的底层实现上,当x过大的时候会产生溢出,过小的时候直接范围近似值0,所以我们在定义交叉熵损失函数的时候需要注意这一点;

1. 当模型返回的值是sigmoid函数映射过后的值,这里假设输入交叉熵的为x,那么我们计算的就是

 -(y*np.log(x)+(1-y)*np.log(1-x))  # y为标签值,x为预测值,标准的二分类交叉熵损失代码

但是log(0)是没有意义的,如果当输入的x为0,就会出现输出的loss为nan,

出现nan的原因

NaN(not a number),在数学表示上表示一个无法表示的数,这里一般还会有另一个表述inf,inf和nan的不同在于,inf是一个超过浮点表示范围的浮点数(其本质仍然是一个数,只是他无穷大,因此无法用浮点数表示,比如1/0),而nan则一般表示一个非浮点数(比如无理数)。

-y * np.log(a) - (1-y) * np.log(1 - a)

当a = y = 0.0, y * np.log(a) = 0 * -inf = nan
当a = y = 1.0, (1 - y) * np.log(1 - a) = 0 * -inf = nan

出现nan的核心原因是log(0.0) = -inf, 所以a的取值才是关键
而通常情况下a趋向0时, y多数就会等于0, 所以0 * -inf = nan

不过a=0, y=1也是有可能出现的, 就是初始化时, a就与y截然相反.之后就不会再出现了, 因为优化总是往好的方向走.

Tensorflow的处理方法

下面的代码中, 用了三种方法解决这个问题, 并对比了速度和精确度

  • 方法1. 给a加上一个极小值
  • 方法2. 当a接近0时, 给a一个极小值
  • 方法3. 出现nan时, 设置损失为0
# coding=utf-8
import tensorflow as tf
import time

NEAR_0 = 1e-10
ZERO = tf.constant(0.0)

# 正常交叉熵
def cross_entropy(y, a):
    return sess.run(-(y * tf.log(a) + (1 - y) * tf.log(1 - a)))

# 方法一:
def method1(y, a):
    return sess.run(-(y * tf.log(a + NEAR_0) + (1 - y) * tf.log(1 - a + NEAR_0)))

# 方法二:
def method2(y, a):
    return sess.run(-(y * tf.log(nan_to_num(a)) + (1 - y) * tf.log(nan_to_num(1 - a))))

# 方法三:
def method3(y, a):
    return sess.run(nan_to_zero(-(y * tf.log(a) + (1 - y) * tf.log(1 - a))))

# 裁剪n
def nan_to_num(n):
    return tf.clip_by_value(n, NEAR_0, 1)

# 重新设置损失值
def nan_to_zero(c):
    return tf.cond(tf.is_nan(c), lambda: ZERO, lambda: c)


sess = tf.Session()

# 会出现nan的主要有两处
# 1. y和a都是0
# 2. y和a都是1
# 也就是loss很小的时候
# 而y=1, a=0, 或者a=1, y=1, 只有开刚开始优化的时候可能出现, 后面loss都会往好的方向走.

# log(0.0) = -inf 负无穷
print '交叉熵:'
print cross_entropy(0.0, 0.0)  # y * tf.log(a) : 0.0 * -inf = nan
print cross_entropy(1.0, 1.0)  # (1 - y) * tf.log(1 - a) : 0.0 * -inf = nan

print '\n方法一:'
# 方法1: 加上一个接近0的数
print method1(0.0, 0.0)
print method1(1.0, 1.0)

print '\n方法二:'
# 方法2: 当a比1e-10还小时, 等于1e-10
print method2(0.0, 0.0)
print method2(1.0, 1.0)

print '\n方法三:'
# 方法3: 出现nan, 赋值为0
print method3(0.0, 0.0)
print method3(1.0, 1.0)

# 哪种方法更好
# 1. 速度, 方法1, 胜
#    方法1, 每次都加一个数
#    方法2, 每次都要比较大小
#    方法3, 不好说
#    结论, 加一个数计算成本更低, 所有速度更快
start_time = time.time()
for i in xrange(100):
    method1(0.0, 0.0)
duration = time.time() - start_time
print '\n方法1用时: {0}'.format(duration)

start_time = time.time()
for i in xrange(100):
    method2(0.0, 0.0)
duration = time.time() - start_time
print '\n方法2用时: {0}'.format(duration)

start_time = time.time()
for i in xrange(100):
    method3(0.0, 0.0)
duration = time.time() - start_time
print '\n方法3用时: {0}'.format(duration)

# 2. 精确度
#    方法1, 无论什么数, 都要加个1e-10
#    方法2, 只有小于1e-10时, 才会起作用
#    方法3, 出现nan, 才会起作用
#    总结, 理论上方法二, 三肯定是更精确, 但是实际看不到任何差别
print '\n'
print '%.100f' % (0.5 + 1e-10)
print '%.100f' % sess.run(tf.clip_by_value(0.5, 1e-10, 1))

print '\n'
print '%.100f' % method1(1.0, 0.5)
print '%.100f' % method2(1.0, 0.5)
print '%.100f' % method3(1.0, 0.5)

2. 直接把sigmoid的计算也包含在交叉熵损失函数中的时候,这里假设模型的输入还是为x,只不过这个x没有经过sigmoid映射,那么交叉熵损失函数为:

-(y*np.log(sigmoid(x))+(1-y)*np.log(1-sigmoid(x)))

还是会遇到1中写的问题,但是这里我们可以先对这个公式进行优化:

 -(y*log(sigmoid(x))+(1-y)*log(1-sigmoid(x))))
=y*-log(1/1+exp(-x))+(1-y)*-log(exp(-x)/(1+exp(-x)))
=y*log(1+exp(-x))+(1-y)*(-log(exp(-x))+log(1+exp(-x)))
=y*log(1+exp(-x))+(1-y)*x+(1-y)log(1+exp(-x)))
=log(1+exp(-x))+(1-y)*x

这样就简单很多了,但是我们最上面说了,exp(x)函数在x很大的时候会溢出,所以上面的公式中,因为输入exp()的-x,所以x很小的时候会溢出。因此针对x<0的时候单独再优化一下,按如下:

  log(1+exp(-x))+(1-y)*x
=log(1+exp(-x))+log(exp(x))-y*x
=log(exp(x)+1)-y*x

这样的话就可以跳过因为输入exp()函数的值过大而溢出的问题,因此输入exp()的值很小的时候是会趋向于0的,输出的值直接约等于0。

所以把上面两个合并起来就可以得到最终的公式:

max(x,0)-x*y+log(1+exp(-abs(x)))

 

 

 

 

 

 

 

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值