今天看代码时看到了别人用np.eye实现one-hot编码,以前不知道这种用法,觉得很实用,所以记录一下。
import numpy as np
def onehot(label, num):
m = label
one_hot = np.eye(num)[m] # num为onehot编码的长度,m为编码前的数组(可以是高维ndarray)
return one_hot
X = np.ones(10000, dtype=np.uint8).reshape(10,10,10,10)
Y = onehot(X, 58)
print(Y.shape, X.shape)
print(Y)
结果: