在keras中做深度网络预测时,有这两个预测函数model.predict_classes(test) 和model.predict(test),本例中是多分类,标签经过了one-hot编码,如[1,2,3,4,5]是标签类别,经编码后为[1 0 0 0 0],[0 1 0 0 0]...[0 0 0 0 1]
- model.predict_classes(test)预测的是类别,打印出来的值就是类别号
同时只能用于序列模型来预测,不能用于函数式模型
predict_test = model.predict_classes(X_test).astype('int')
inverted = encoder.inverse_transform([predict_test])
print(predict_test)
print(inverted[0])
[1 0 0 ... 1 0 0] [2. 1. 1. ... 2. 1. 1.]
- model.predict(test)预测的是数值,而且输出的还是5个编码值,不过是实数,预测后要经过argmax(predict_test,axis=1)
predict_test = model.predict(X_test)
predict = argmax(predict_test,axis=1) #axis = 1是取行的最大值的索引,0是列的最大值的索引
inverted = encoder.inverse_transform([predict])
print(predict_test[0:3])
print(argmax(predict_test,axis=1))
print(inverted)
[[9.9992561e-01 6.9890179e-05 2.3676146e-06 1.9608513e-06 2.5582506e-07] [9.9975246e-01 2.3708738e-04 4.9365349e-06 5.2166861e-06 3.3735736e-07] [9.9942291e-01 5.5233808e-04 8.9857504e-06 1.5617061e-05 2.4388814e-07]] [0 0 0 ... 0 0 0] [[1. 1. 1. ... 1. 1. 1.]]由于前几个和后几个每个预测值编码都是第一列最大,所以索引是0,反编码后是1