Top-k准确率
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #限制控制台打印日志级别
tf.random.set_seed(2467)
def accuracy(output, target, topk=(1,)):
# output [10,6]
maxk = max(topk)
batch_size = target.shape[0]
pred = tf.math.top_k(output, maxk).indices # 前K个最大值的索引 [10,maxk]
# print('每行top-6 最大值',tf.math.top_k(output, maxk).values)
print('每行top-6 最大值下标'