假设为6分类,有4个样本的标签为1, 1, 4, 3。利用numpy进行one_hot编码:
import numpy as np
np.eye(6)[np.array([1, 1, 4, 3])]
one hot后的输出为:
array([[0., 1., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0.],
[0., 0., 0., 1., 0., 0.]])
其中输出的行个数为样本的个数,列数为几分类的个数。