tensorflow.one_hot
tf.one_hot(indices=labels, depth=depth, axis=axis)(注意:axis 必须用关键字参数传入,因为第三个位置参数是 on_value)
作用如下:
其中:labels 是标签向量 y,depth 是右边 one-hot 矩阵的深度(也就是类别的总数);axis=0 表示类别维度放在第 0 维,即结果的每一列对应一个样本。
import numpy as np
import tensorflow as tf
def one_hot_matrix(lable, clas):
    """Build a one-hot encoding of ``lable`` and evaluate it to a concrete value.

    Args:
        lable: Class indices to encode; passed to ``tf.one_hot`` as ``indices``.
        clas: Number of classes; passed to ``tf.one_hot`` as ``depth``.

    Returns:
        The value produced by running the one-hot tensor in a TensorFlow
        session. With ``axis=0`` the class dimension is placed first, so a
        1-D ``lable`` of length m should yield a ``(clas, m)`` result.
    """
    # axis=0 puts the class dimension first instead of on the last axis.
    encoded = tf.one_hot(indices=lable, depth=clas, axis=0)

    # Graph-mode evaluation; the session is closed when the block exits.
    with tf.Session() as session:
        return session.run(encoded)
# Demo input: five sample labels (classes 1 through 5) and a depth of 6,
# so class index 0 is never set in the encoded output.
y = np.array([1, 2, 3, 4, 5])
c = 6

# Encode the sample labels and print the evaluated result.
print(one_hot_matrix(lable=y, clas=c))