TF day2 里面说到神经网络的工作, 得到的预测值与实际的 label 比较得到了误差,我们的目标是让这个误差越来越小。
那么这个误差怎么定义呢? 这就是损失函数了~~~
分类问题和回归问题都属于监督学习.
对于分类问题, 设置 n 个output node, 其中 n 为类别的个数, 因此输出向量是, i表示第i个样本. 在理想情况下,如果一个样本属于类别 k ,那么这个节点的输出值为1, 其他输出值为0, 以minist中数字1为例, 即期望向量为[1,0,0,0,0,0,0…].
怎么判断输出向量与期望向量有多接近呢? 交叉熵是常用的判别方法之一.
因为交叉熵比较的是两个概率分布的距离, 先对预测向量进行softmax处理. 通过softmax层, 神经网络的输出变成一个概率分布.
交叉熵:
在神经网络中,p代表真实值,q表示预测值.
代码实现:
import tensorflow as tf
import pandas as pd
import numpy as np
df = pd.read_excel('/home/pan-xie/PycharmProjects/ML/jxdc/打分/项目评分表last.xls')
X = df.iloc[3:5,3:6]
logits = np.array(X,dtype=float)
print(logits)
Y = df.iloc[3:5,2]
label = []
for i in Y:
if i =='B':
label.append([1,0,0])
elif i =='C':
label.append([0,1,0])
else:
label.append([0,0,1])
print(label)
y = tf.nn.softmax(logits)
cross_entroy = -tf.reduce_sum(label*tf.log(y))
cross_entroy2 = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits=logits,labels=label))
with tf.Session() as sess:
tf.global_variables_initializer().run()
softmax = sess.run(y)
cross1 = sess.run(cross_entroy)
cross2 = sess.run(cross_entroy2)
print("softmax:")
print(softmax)
print("for two steps")
print(cross1)
print(cross2)
Tensorflow将交叉熵和softmax进行了同一的封装, 提供了几个函数:
1. tf.nn.softmax_cross_entropy_with_logits()
softmax_cross_entropy_with_logits(_sentinel=None,labels=None, logits=None,dim=-1, name=None):
参数:
- logits 对应预测值,未经过softmax处理层的输出值. 如果有batch的话,它的大小就是[batchsize,num_classes]
- labels 对应真实值 , 这里要好好想想真实值怎么带入?在minist中是one-hot vector, 但类别是A,B,C,D呢?
注意!!!这个函数的返回值并不是一个数,而是一个向量,如果要求交叉熵,我们要再做一步tf.reduce_sum 操作,就是对向量里面所有元素求和,最后才得到交叉熵,如果求loss,则要做一步tf.reduce_mean操作,对向量求均值!
2) tf.nn.sparse_softmax_cross_entropy_with_logits()
tf.nn.sparse_softmax_cross_entropy_with_logits(_sentinel=None,labels=None, logits=None, name=None)
与tf.nn.softmax_cross_entropy_with_logits()十分相似,唯一的区别在于labels,该函数的标签labels要求是排他性的即只有一个正确类别,labels的形状要求是[batch_size] 而值必须是从0开始编码的int32或int64,而且值范围是[0, num_class)
3)tf.nn.sigmoid_cross_entropy_with_logits(_sentinel=None,labels=None, logits=None, name=None)
4) tf.nn.weighted_cross_entropy_with_logits(targets, logits, pos_weight, name=None)
参考:
http://blog.csdn.net/marsjhao/article/details/72630147
http://blog.csdn.net/u014595019/article/details/52562159