TensorFlow one-hot 编码 tf.one_hot 的基本用法及实例代码

一、环境

TensorFlow API r1.12

CUDA 9.2 V9.2.148

cudnn64_7.dll

Python 3.6.3

Windows 10

 

二、官方说明

将输入的 indices 转化为 one-hot 编码形式

indices 中指定的位置取值为 one_value 参数值,其他的位置都取值 off_value 参数值

参数 one_value 和 参数 off_value 的数据类型必须相同,如果指定了 dtype,就必须都为该数据类型

如果参数 one_value 没有指定,默认取 1 ,类型为指定的 dtype

如果参数 off_value 没有指定,默认取 0 ,类型为指定的 dtype

如果输入参数 indices 的阶是 N,则输出数据的阶 N+1;新轴在参数 axis 的维度上添加(不指定 axis 时默认添加在最后面的维度)

如果 indices 是标量,输出结果的形状为长度为 depth 的向量

如果 indices 是长度为 features 的向量,输出结果的形状为:

features x depth if axis == -1
depth x features if axis == 0

 如果 indices 是形状为 [batch, features] 的矩阵,输出结果的形状为:

batch x features x depth if axis == -1
batch x depth x features if axis == 1
depth x batch x features if axis == 0

如果参数 dtype 不指定,该方法默认假定数据格式与参数 on_value 或 off_value 相同,如果 dtype、on_value 和 off_value 都不指定,则 dtype 默认是 tf.float32

注意:如果输出结果是非数字形式,如:tf.string、tf.bool 等,则 on_value 和 off_value 都必须设置

https://tensorflow.google.cn/api_docs/python/tf/one_hot

tf.one_hot(
    indices,
    depth,
    on_value=None,
    off_value=None,
    axis=None,
    dtype=None,
    name=None
)

参数:

indices:值为索引的张量

depth:指定独热编码维度的标量

on_value:索引 indices[j] = i 位置处填充的标量,默认为 1

off_value:索引 indices[j] != i 所有位置处填充的标量,默认为 0

axis:填充的轴,默认为 -1(最里面的新轴)

dtype:输出张量的数据格式

name:可选参数,操作的名称

返回:

独热编码 one-hot 张量

 

三、实例

(1)一维列表形式的整型类别标签转换为 one-hot 类别标签形式

>>> import tensorflow as tf
>>> labels = [0,1,2]
>>> one_hot_labels = tf.one_hot(indices=labels,depth=3, on_value=1, off_value=0, axis=-1, dtype=tf.int32, name="one-hot")
>>> one_hot_labels
<tf.Tensor 'one-hot_1:0' shape=(3, 3) dtype=int32>
>>> with tf.Session() as sess:
...     print(sess.run(one_hot_labels))
...
[[1 0 0]
 [0 1 0]
 [0 0 1]]

(2)二维列表形式的整型类别标签转换为 one-hot 类别标签形式

>>> import tensorflow as tf
>>> labels = [[0,1],[2,3]]
>>> labels
[[0, 1], [2, 3]]
>>> one_hot_labels = tf.one_hot(indices=labels,depth=3, on_value=1.0, off_value=0.0, axis=-1)
>>> one_hot_labels
<tf.Tensor 'one_hot:0' shape=(2, 2, 3) dtype=float32>
>>> with tf.Session() as sess:
...     print(sess.run(one_hot_labels))
...
[[[1. 0. 0.]
  [0. 1. 0.]]

 [[0. 0. 1.]
  [0. 0. 0.]]]

四、注意事项

使用参数“dtype”定义输出张量的数据格式时,一定要参数“on_value”和“off_value”的数据格式对应,否则会报错!

如:dtype=tf.float32,而 on_value=1, off_value=0,即前者指浮点型,后两者为整形,会报错:

“TypeError: dtype <dtype: 'int32'> of on_value does not match dtype parameter <dtype: 'float32'>”

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

csdn-WJW

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值