tensorflow 中交叉熵损失函数的计算方法
下列代码中
- l1 是自定义计算方式
- l2 是利用 tf.nn.softmax_cross_entropy_with_logits_v2 输入的真实标签需要onehot处理
- l3 是利用 tf.nn.sparse_softmax_cross_entropy_with_logits
输入的真实标签是原始多分类中的为1的类别的索引 - l4 是losses方法直接计算出最终loss,相比于上述方法,多求了一个平均值。但是在函数内部,这个loss被收入到 GraphKeys.LOSSES 中。
一般不建议用自定义函数来实现交叉熵,因为会有边界条件的判断来自己设置,比如***“ z = - tf.log(y + 1e-12) ”*** ,这里加了一个极小值,保证log函数的输入不是0. 类似的情况在其他给定的函数api中都有解决方案,不用过多担心。
import tensorflow as tf
x = tf.constant([[0.1,0.5,1.1], [10,3,2.]])
label = tf.constant([2,0])
label_onehot = tf.one_hot(label, depth=3)
y = tf.nn.softmax(x)
z = - tf.log(y + 1e-12)
l1 = tf.reduce_sum(label_onehot * z, axis=1)
l2 = tf.nn.softmax_cross_entropy_with_logits_v2(labels=label_onehot, logits=x)
l3 = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=x)
l1_mean = tf.reduce_mean(l1)
l2_mean = tf.reduce_mean(l2)
l3_mean = tf.reduce_mean(l3)
l4 = tf.losses.softmax_cross_entropy(onehot_labels=label_onehot, logits=x) # 额外在ops.GraphKeys.LOSSES加入这个loss
with tf.Session() as sess:
print("l1:", sess.run(l1), sess.run(l1_mean))
print("l2:", sess.run(l2), sess.run(l2_mean))
print("l3:", sess.run(l3), sess.run(l3_mean))