【分类问题】【numpy】网络输出预测向量->标签
即 取出每一项的最大值的索引
probabilities是shape=(B,N,class)的深度学习输出。N个样本,每个样本有长为(class)的可能性向量。
选取class中最大的一位的索引,作为预测标签
使用代码
label= np.argmax(probabilities, axis=-1)
以压缩axis=-1最后一维为准,找出每一个项的最大值的索引。
输出的label的shape=(B,N)。
Example
测试代码:
probabilities = np.random.randn(4,5)
label = np.argmax(probabilities,axis=-1)
结果如下