Tensorflow中one_hot() 函数用法

官网默认定义如下:
one_hot(indices, depth, on_value=None, off_value=None, axis=None, dtype=None, name=None)
该函数的功能主要是转换成one_hot类型的张量输出。


参数功能如下:
  1)indices中的元素指示on_value的位置,不指示的地方都为off_value。indices可以是向量、矩阵。
  2)depth表示输出张量的尺寸,indices中元素默认不超过(depth-1),如果超过,输出为[0,0,···,0]
  3)on_value默认为1
  4)off_value默认为0
  5)dtype默认为tf.float32


下面用几个例子说明一下:
1. indices是向量
 1 import tensorflow as tf
 2 
 3 indices = [0,2,3,5]
 4 depth1 = 6   # indices没有元素超过(depth-1)
 5 depth2 = 4   # indices有元素超过(depth-1)
 6 a = tf.one_hot(indices,depth1)
 7 b = tf.one_hot(indices,depth2)
 8 
 9 with tf.Session() as sess:
10     print('a = \n',sess.run(a))
11     print('b = \n',sess.run(b))

运行结果:

# 输入是一维的,则输出是一个二维的
a = [[1. 0. 0. 0. 0. 0.] [0. 0. 1. 0. 0. 0.] [0. 0. 0. 1. 0. 0.] [0. 0. 0. 0. 0. 1.]]      # shape=(4,6) b = [[1. 0. 0. 0.] [0. 0. 1. 0.] [0. 0. 0. 1.] [0. 0. 0. 0.]]          # shape=(4,4)

2. indices是矩阵

 1 import tensorflow as tf
 2 
 3 indices = [[2,3],[1,4]]
 4 depth1 = 9   # indices没有元素超过(depth-1)
 5 depth2 = 4   # indices有元素超过(depth-1)
 6 a = tf.one_hot(indices,depth1)
 7 b = tf.one_hot(indices,depth2)
 8 
 9 with tf.Session() as sess:
10     print('a = \n',sess.run(a))
11     print('b = \n',sess.run(b))

运行结果:

# 输入是二维的,则输出是三维的
a = [[[0. 0. 1. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 1. 0. 0. 0. 0. 0.]] [[0. 1. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 1. 0. 0. 0. 0.]]]    # shape=(2,2,9) b = [[[0. 0. 1. 0.] [0. 0. 0. 1.]] [[0. 1. 0. 0.] [0. 0. 0. 0.]]]             # shape=(2,2,4)

 

 

转载于:https://www.cnblogs.com/muzidaitou/p/11262820.html

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值