一、环境
TensorFlow API r1.12
CUDA 9.2 V9.2.148
cudnn64_7.dll
Python 3.6.3
Windows 10
二、官方说明
将输入的 indices 转化为 one-hot 编码形式
indices 中指定的位置取值为 one_value 参数值,其他的位置都取值 off_value 参数值
参数 one_value 和 参数 off_value 的数据类型必须相同,如果指定了 dtype,就必须都为该数据类型
如果参数 one_value 没有指定,默认取 1 ,类型为指定的 dtype
如果参数 off_value 没有指定,默认取 0 ,类型为指定的 dtype
如果输入参数 indices 的阶是 N,则输出数据的阶 N+1;新轴在参数 axis 的维度上添加(不指定 axis 时默认添加在最后面的维度)
如果 indices 是标量,输出结果的形状为长度为 depth 的向量
如果 indices 是长度为 features 的向量,输出结果的形状为:
features x depth if axis == -1
depth x features if axis == 0
如果 indices 是形状为 [batch, features] 的矩阵,输出结果的形状为:
batch x features x depth if axis == -1
batch x depth x features if axis == 1
depth x batch x features if axis == 0
如果参数 dtype 不指定,该方法默认假定数据格式与参数 on_value 或 off_value 相同,如果 dtype、on_value 和 off_value 都不指定,则 dtype 默认是 tf.float32
注意:如果输出结果是非数字形式,如:tf.string、tf.bool 等,则 on_value 和 off_value 都必须设置
https://tensorflow.google.cn/api_docs/python/tf/one_hot
tf.one_hot(
indices,
depth,
on_value=None,
off_value=None,
axis=None,
dtype=None,
name=None
)
参数:
indices:值为索引的张量
depth:指定独热编码维度的标量
on_value:索引 indices[j] = i 位置处填充的标量,默认为 1
off_value:索引 indices[j] != i 所有位置处填充的标量,默认为 0
axis:填充的轴,默认为 -1(最里面的新轴)
dtype:输出张量的数据格式
name:可选参数,操作的名称
返回:
独热编码 one-hot 张量
三、实例
(1)一维列表形式的整型类别标签转换为 one-hot 类别标签形式
>>> import tensorflow as tf
>>> labels = [0,1,2]
>>> one_hot_labels = tf.one_hot(indices=labels,depth=3, on_value=1, off_value=0, axis=-1, dtype=tf.int32, name="one-hot")
>>> one_hot_labels
<tf.Tensor 'one-hot_1:0' shape=(3, 3) dtype=int32>
>>> with tf.Session() as sess:
... print(sess.run(one_hot_labels))
...
[[1 0 0]
[0 1 0]
[0 0 1]]
(2)二维列表形式的整型类别标签转换为 one-hot 类别标签形式
>>> import tensorflow as tf
>>> labels = [[0,1],[2,3]]
>>> labels
[[0, 1], [2, 3]]
>>> one_hot_labels = tf.one_hot(indices=labels,depth=3, on_value=1.0, off_value=0.0, axis=-1)
>>> one_hot_labels
<tf.Tensor 'one_hot:0' shape=(2, 2, 3) dtype=float32>
>>> with tf.Session() as sess:
... print(sess.run(one_hot_labels))
...
[[[1. 0. 0.]
[0. 1. 0.]]
[[0. 0. 1.]
[0. 0. 0.]]]
四、注意事项
使用参数“dtype”定义输出张量的数据格式时,一定要参数“on_value”和“off_value”的数据格式对应,否则会报错!
如:dtype=tf.float32,而 on_value=1, off_value=0,即前者指浮点型,后两者为整形,会报错:
“TypeError: dtype <dtype: 'int32'> of on_value does not match dtype parameter <dtype: 'float32'>”