def acc(output, label):
# output: (batch, num_output) float32 ndarray
# label: (batch, ) int32 ndarray
return (output.argmax(axis=1) == label.astype(‘float32’)).mean().asscalar()
在Gluon文档里有这个计算accuracy的函数,就一行看不懂,分析一下。
首先argmax,argmax的意思是返回最大值的坐标。
- axis缺省为全局最大(直接用报错,可以np.argmax(a) )
- axis = 0 为每列最大
- axis = 1为每行最大
x = nd.array(((1,2,3),(3,4,5)))
>>> x.argmax(axis=1)
[2. 2.]
<NDArray 2