one-hot encoding
one-hot encoding 一般是在有监督学习中对数据集进行标注时使用的,指的是在分类中,将存在数据类别的那一类用X表示,不存在的用Y表示,这里的X常常是1, Y常常是0。
举个例子:
比如我们有一个5分类问题,我们有数据
(
X
i
,
Y
i
)
(X_{i}, Y_{i})
(Xi,Yi),其中类别
Y
i
Y_{i}
Yi有5种取值,所以如果所以如果
Y
j
Y_{j}
Yj为第一类那么其独热编码为: [1,0,0,0,0],如果是第二类那么独热编码为:[0,1,0,0,0],也就是说只对存在有该类别的数的位置上进行标记为1,其他皆为0。这个编码方式经常用于多分类问题,特别是损失函数为交叉熵函数的时候。
tf.one_hot()
tf.one_hot(
indices, #输入,这里是一维的
depth, #one hot dimension
on_value=None, #默认1
off_value=None, #默认0
axis=None,
dtype=None,
name=None
)
需要指定indices和depth,其中depth是编码深度,on_value和off_policy相当于编码后开闭值,如同我们刚才描述的X和Y值,需要和dtype相同类型(指定了dtype的情况下),axis指定编码的轴。
实例:
import tensorflow as tf
var0 = tf.one_hot(indices=[1, 2, 3], depth=3, axis=0)
var1 = tf.one_hot(indices=[1, 2, 3], depth=4, axis=0)
var2 = tf.one_hot(indices=[1, 2, 3], depth=4, axis=1)
var3 = tf.one_hot(indices=[1, 2, 3], depth=4, axis=-1)
with tf.Session() as sess:
sess.run(tf.global_variable_initializer())
a0 = sess.run(var0)
a1 = sess.run(var1)
a2 = sess.run(var2)
a3 = sess.run(var3)
print("var0(axis=0 depth=3)\n",a0)
print("var1(axis=0 depth=4P)\n",a1)
print("var2(axis=1)\n",a2)
print("var3(axis=-1)\n",a3)
输出:
Reference:
https://www.jianshu.com/p/c5b4ec39713b