该函数用于将输入转换成one-hot
形式:
tf.one_hot(indices, depth, on_value, off_value, axis)
indices
:非负整数表示的标签列表,len(indices)
就是分类的类别数。tf.one_hot
返回的张量的阶数为indeces
的阶数加上1
。当indices
的某个分量取-1
时,即对应的向量没有独热值。depth
:每个独热向量的维度。on_value
:独热值。off_value
:非独热值。axis
:指定第几阶为depth
维独热向量,默认为-1
,即指定张量的最后一维为独热向量。例如对于一个2
阶张量而言,axis = 0
时,每个列向量是一个独热的depth
维向量;axis = 1
时,每个行向量是一个独热的depth
维向量。
import tensorflow as tf
import numpy as np
z = np.random.randint(0, 10, size=[10])
y = tf.one_hot(z, 10, on_value=1, off_value=None, axis=0)
with tf.Session()as sess:
print(z)
print(sess.run(y))
执行结果:
[4 4 5 8 9 0 5 0 3 1]
[[0 0 0 0 0 1 0 1 0 0]
[0 0 0 0 0 0 0 0 0 1]
[0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 1 0]
[1 1 0 0 0 0 0 0 0 0]
[0 0 1 0 0 0 1 0 0 0]
[0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0]
[0 0 0 1 0 0 0 0 0 0]
[0 0 0 0 1 0 0 0 0 0]]