tf.one_hot()用法与理解

    前几天看代码时遇到tf.one_hot()函数,当时没太关注,代码正常运行就OK,终于有时间回头好好看看这个函数。

    先来看看官方文档上的解释:

tf.one_hot(
    indices,
    depth,
    on_value=None,
    off_value=None,
    axis=None,
    dtype=None,
    name=None
)
Returns a one-hot tensor(返回一个one_hot张量).

The locations represented by indices in indices take value on_value, while all other locations take value off_value.
(由indices指定的位置将被on_value填充, 其他位置被off_value填充).

on_value and off_value must have matching data types. If dtype is also provided, they must be the same data type as specified by dtype.
(on_value和off_value必须具有相同的数据类型).

If on_value is not provided, it will default to the value 1 with type dtype.

If off_value is not provided, it will default to the value 0 with type dtype.

If the input indices is rank N, the output will have rank N+1. The new axis is created at dimension axis (default: the new axis is appended at the end).
(如果indices是N维张量,那么函数输出将是N+1维张量,默认在最后一维添加新的维度).

If indices is a scalar the output shape will be a vector of length depth.
(如果indices是一个标量, 函数输出将是一个长度为depth的向量)

If indices is a vector of length features, the output shape will be:

  features x depth if axis == -1.
(如果indices是一个长度为features的向量,则默认输出一个features*depth形状的张量)
  depth x features if axis == 0.
(如果indices是一个长度为features的向量,axis=0,则输出一个depth*features形状的张量)

If indices is a matrix (batch) with shape [batch, features], the output shape will be:

  batch x features x depth if axis == -1
(如果indices是一个形状为[batch, features]的矩阵,axis=-1(默认),则输出一个batch * features * depth形状的张量)

  batch x depth x features if axis == 1
(如果indices是一个形状为[batch, features]的矩阵,axis=1,则输出一个batch * depth * features形状的张量)
  depth x batch x features if axis == 0
(如果indices是一个形状为[batch, features]的矩阵,axis=0,则输出一个depth * batch * features形状的张量)

姑且把这个函数翻译成“独热编码”, 编码只是数据的一种表示方法,独热编码就是把要把部分数据与其他数据明确的区分开来,为此我们采用了0和1来区分不同的数据,传入one_hot()函数的是数据的indices指定要编码的数据位置, tf给的example可以很直观的解释这一过程:

indices = [0, 1, 2]  #输入数据(是个向量)需要编码的索引是[0,1,2]
depth = 3
tf.one_hot(indices, depth)  # output: [3 x 3]
# [[1., 0., 0.],
#  [0., 1., 0.],
#  [0., 0., 1.]]

indices = [0, 2, -1, 1]  #输入数据(是个向量)的需要编码的索引是[0,2,-1,1]
depth = 3
tf.one_hot(indices, depth,
           on_value=5.0, off_value=0.0,
           axis=-1)  # output: [4 x 3]
# [[5.0, 0.0, 0.0],  # one_hot(0)  对位置0处的数据进行one_hot编码
#  [0.0, 0.0, 5.0],  # one_hot(2)  对位置2处的数据进行one_hot编码
#  [0.0, 0.0, 0.0],  # one_hot(-1) 对位置-1处的数据进行one_hot编码
#  [0.0, 5.0, 0.0]]  # one_hot(1)  对位置1处的数据进行one_hot编码

indices = [[0, 2], [1, -1]]   #输入数据是个矩阵
depth = 3
tf.one_hot(indices, depth,
           on_value=1.0, off_value=0.0,
           axis=-1)  # output: [2 x 2 x 3]
# [[[1.0, 0.0, 0.0],   # one_hot(0)  对位置(0,0)处的数据进行one_hot编码
#   [0.0, 0.0, 1.0]],  # one_hot(2)  对位置(0,2)处的数据进行one_hot编码
#  [[0.0, 1.0, 0.0],   # one_hot(1)  对位置(1,1)处的数据进行one_hot编码
#   [0.0, 0.0, 0.0]]]  # one_hot(-1) 对位置(1,-1)处的数据进行one_hot编码

最后我们的数据就被编码成这种形式的张量: one_hot tensor, 注意axis代表新插入维度的位置.

  • 21
    点赞
  • 74
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值