TF:tf.losses.sparse_softmax_cross_entropy

tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=labels)等价与

tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels) 加 tf.reduce_mean。

它们的参数维度:

  1. logits: [batch_size, num_classes]
  2. labels: [batch_size,]

计算过程:

batch_size=1
num_classes=3
logits = tf.constant([3,1,-3], shape=[batch_size, num_classes], dtype=tf.float32)
labels = tf.constant([2], shape=[batch_size,], dtype=tf.int32)

logits是二维列表,labels是一维列表。batch_size为1时,输入x得到x对应3个类别的logit是[3,1,-3],x的实际标签是2。

求交叉熵可以分为两步:

  1. 求softmax:
    softmax = tf.nn.softmax(logits)#[[0.87887824 0.11894324 0.00217852]]

    其中,tf.nn.softmax的计算步骤是,对列表[3,1,-3],S_{i}=\frac{e^{i}}{\sum_{j=1}^{3}e^{j}}e_{1}=3,e_{2}=1,e_{3}=-3

  2. 对得到的列表每个元素求对数再求相反数得到[0.1291089166980569, 2.1291089016197424, 6.1291089069914575]。
    log_soft = [-math.log(i) for i in softmax.eval()[0]]#[0.1291089166980569, 2.1291089016197424, 6.1291089069914575]
    cross_entropy = log_soft[2]#-log(0.00217852) = 6.12910890699

    由于标签是2,所以交叉熵取第3个数6.12910890699。

对一个输入为x的3分类结果,logits=[3,1,-3]的意义是x对应标签1、2、3的logit(可以理解为代价)分别为3、1、-3。因为logit表示代价不直观,所以转化为log_soft。从log_soft可以看出x对应标签1的代价最小(0.129),对应标签2的代价最大(6.129)。labels给出的实际标签是2,所以代价就是6.129。

完整代码:

import tensorflow as tf
import math

sess = tf.InteractiveSession()

batch_size=1
num_classes=3
logits = tf.constant([3,1,-3], shape=[batch_size, num_classes], dtype=tf.float32)
labels = tf.constant([2], shape=[batch_size,], dtype=tf.int32)

softmax = tf.nn.softmax(logits)
print"\n--softmax:",softmax.eval()
log_soft = [-math.log(i) for i in softmax.eval()[0]]
print"--1:",log_soft[2]

print"--2:",tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=labels).eval()

temp = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
print"--3:",tf.reduce_mean(temp).eval()

 

  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值