首先肯定需要解释下什么叫做独热编码(one-hot encoding),独热编码一般是在有监督学习中对数据集进行标注时候使用的,指的是在分类问题中,将存在数据类别的那一类用X表示,不存在的用Y表示,这里的X常常是1, Y常常是0。
举个例子:
我们是在做一个1和0数字识别的应用,就是将一个图片分成1这一类还是0这一类的分类问题,我们一般采用softmax来做分类。
softmax回归简述:
对于图像来说,最后提取的特征被放在1024*1的向量中(这里假设最终提取到的特征是1024个),那么我们需要把这1024个对应的特征值输出全连接神经网络,最后经过几层的降维最终输出我们需要分的类别。(假设我们经过2层全连接层最终将所有图片分为1和0两类,所以最后的输出层只有两个神经元输出值)所以,这两个神经元的输出值就代表着0和1的两个分类。怎么确定到底是0还是1呢,这就要看最终输出的0和1的值与我们事先标注的0和1的标签哪一个更接近。但是这种总归是不靠谱,因为输出的值有大有小,不稳定,误差大。于是我们采用softmax回归一个式子将最终这些神经元输出的值全部转换为概率值。也就是在最后的两个神经元的输出之后再加上一个softmax回归层,之后再输出的两个值就分别是成为这两类的概率了。就相当于一个按分的类数将每个类的输出值归一化的过程。(所有类别输出总和=1)
所以,一张图片经过神经网络后得到的两个概率值,可以认为是这张图片分类到0和1的可能性。
好了,训练数据的输出的问题搞定了,那么标签怎么办呢?因为现在标签还是用值来表示的,比如0的一张图片,它的标签就是0;1的一张图片,它的标签就是1,这种用值来表示就显得不靠谱。那如果是0-9的识别,那么难道9对应的图片,它的标签就是9吗?显然不是!
对于一幅图片的标签来说,他也应该有等于分类数目的标注值。比如,我们最终识别0和1分两类,那么一张训练图片最终的输出有两个值(分别是属于1和属于0的概率),那么它的标签也应该有两个值来对应,那么就也应该是属于1和属于0的概率,这样才能训练出现误差嘛。但是对于1的图片它的标签属于1的概率自然是1,属于0的概率自然是0,所以其实标签的值应该是一个有序的数列【1,0】或【0,1】。所以构造标签,只需要在对应的类上写1,在其他的类上写0就可以了。
好了,说到这里,我也是彻底明白了。所以对于分类问题,我们既需要对输出数据进行softmax回归,也要对标签进行以上处理。
好了,罗里吧嗦,我在接着说tensorflow中对于标签数据的处理。采用tf.one_hot将样本的label转换成one_hot的形式,方便进行softmax计算。
one_hot(
indices,#输入,这里是一维的
depth,# one hot dimension.
on_value=None,#output 默认1
off_value=None,#output 默认0
axis=None,
dtype=None,
name=None
)
需要指定indices,和depth,其中depth是编码深度,on_value和off_value相当于是编码后的开闭值,如同我们刚才描述的X值和Y值,需要和dtype相同类型(指定了dtype的情况下),axis指定编码的轴。这里给个小的实例:
import tensorflow as tf
var0 = tf.one_hot(indices=[1, 2, 3], depth=3, axis=0)
var1 = tf.one_hot(indices=[1, 2, 3], depth=4, axis=0)
var2 = tf.one_hot(indices=[1, 2, 3], depth=4, axis=1)
# axis=1 按行排
var3 = tf.one_hot(indices=[1, 2, 3], depth=4, axis=-1)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
a0 = sess.run(var0)
a1 = sess.run(var1)
a2 = sess.run(var2)
a3 = sess.run(var3)
print("var0(axis=0 depth=3)n",a0)
print("var1(axis=0 depth=4P)n",a1)
print("var2(axis=1)n",a2)
print("var3(axis=-1)n",a3)
结果:
我们实例中的例子:
train_labels_one_hot = tf.one_hot(train_labels, 2, on_value=1.0, off_value=0.0)
test_labels_one_hot = tf.one_hot(test_labels, 2, on_value=1.0, off_value=0.0)