看一段代码,主要看out变量那行开始
import numpy as np
import paddle
import paddle.nn as nn
from paddle.static import InputSpec
device = paddle.set_device('cpu') # or 'gpu'
input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label')
net = nn.Sequential(
nn.Linear(784, 200),
nn.Tanh(),
nn.Linear(200, 10),
nn.Softmax())
model = paddle.Model(net, input, label)
model.prepare()
data = np.random.random(size=(4,784)).astype(np.float32)
out = model.predict_batch([data])
pred = (out[0].tolist())[0]
pridict_pred = pred.index(max(pred))
print(out[0][0][0])
print('----------------------------------------------')
print(out[0][0])
print('----------------------------------------------')
print(out[0])
print('----------------------------------------------')
print(out)
print(pred)
print(pridict_pred)
输出结果为:
0.08016773 ---------------------------------------------- [0.08016773 0.0288931 0.17969687 0.21065278 0.21539962 0.06590215 0.1056487 0.03769546 0.04641107 0.02953248] ---------------------------------------------- [[0.08016773 0.0288931 0.17969687 0.21065278 0.21539962 0.06590215 0.1056487 0.03769546 0.04641107 0.02953248] [0.06550846 0.03023597 0.36523694 0.25914553 0.0519002 0.06438435 0.05826315 0.03064454 0.03112181 0.04355912] [0.0454619 0.0548484 0.2597019 0.2091225 0.19776115 0.06514735 0.04594166 0.03004312 0.04092751 0.05104466] [0.02242672 0.05681366 0.34513056 0.16498643 0.14221473 0.05016565 0.09210744 0.02535165 0.05506216 0.04574092]] ---------------------------------------------- [array([[0.08016773, 0.0288931 , 0.17969687, 0.21065278, 0.21539962, 0.06590215, 0.1056487 , 0.03769546, 0.04641107, 0.02953248], [0.06550846, 0.03023597, 0.36523694, 0.25914553, 0.0519002 , 0.06438435, 0.05826315, 0.03064454, 0.03112181, 0.04355912], [0.0454619 , 0.0548484 , 0.2597019 , 0.2091225 , 0.19776115, 0.06514735, 0.04594166, 0.03004312, 0.04092751, 0.05104466], [0.02242672, 0.05681366, 0.34513056, 0.16498643, 0.14221473, 0.05016565, 0.09210744, 0.02535165, 0.05506216, 0.04574092]], dtype=float32)] [0.0801677331328392, 0.02889309823513031, 0.17969687283039093, 0.21065278351306915, 0.2153996229171753, 0.06590215116739273, 0.1056487038731575, 0.03769545629620552, 0.04641106724739075, 0.029532475396990776] 4
我认为pred = (out[0].tolist())[0]
意思是将是将一个nparray形式的数组转为矩阵形式,shape为一个4*1*10的矩阵,将第一个“列表”返回