tf.one_hot()
作用:转换为one-hot 编码格式,由于我们一般预测结尾使用softmax,导致结果全为one-hot形式,因此我们在做测试集时,需要将label转换为one-hot格式,或者将预测结果的one-hot格式转换为数组形式;
关键参数:indices , depth
indices: 传入tensor,如[1,0,3,2]
depth:one-hot的编码深度
代码示例:
one = np.array([1,0,3,2])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tensor_one = tf.convert_to_tensor(one)
print(sess.run(tf.one_hot(tensor_one,depth=4)))
结果,数组里的值全部为one-hot编码格式里1的下标单位:
[[0. 1. 0. 0.]
[1. 0. 0. 0.]
[0. 0. 0. 1.]
[0. 0. 1. 0.]]
如果此时将depth = 3,就会发现,one-hot 无法编码超出的范围,具体见代码:
one = np.array([1,0,3,2])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tensor_one = tf.convert_to_tensor(one)
print(sess.run(tf.one_hot(tensor_one,depth=3)))
结果,看到第三行就无法体现array[1,0,3,2]中的3,因此,depth 也就是array 的长度,平时可以使用depth = len(array):
[[0. 1. 0.]
[1. 0. 0.]
[0. 0. 0.]
[0. 0. 1.]]
接下来,考虑one-hot 转array
我们使用tf.argmax()来转换:
tf.argmax详细使用见链接:tf.argmax使用,点我
接下来我们就将之前转换结果转换回来
代码:
import numpy as np
one = np.array([1,0,3,2])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tensor_one = tf.convert_to_tensor(one)
one_hot_res = tf.one_hot(tensor_one,depth=4)
print(sess.run(tf.argmax(one_hot_res,0)))
结果如下:
one_hot_res = tf.one_hot(tensor_one,depth=4)
print(sess.run(tf.argmax(one_hot_res,0)))