While testing a multi-class model with the eval_image_classifier.py script recently, I noticed that the Accuracy reported by slim.metrics disagreed with the value I computed myself, so I went through the source code. It turns out that the reported Accuracy is simply the overall fraction of correct predictions (equivalently, the frequency-weighted average of the per-class recalls, i.e. top-1 recall), and that the Precision and Recall metrics are computed for the binary case only; for a multi-class model they are not applicable at all. So to evaluate a multi-class model properly, you have to compute these metrics yourself. Let's walk through the source code.
(1) Add the metrics you want to inspect to the code
# Define the metrics:
names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
    'TP': slim.metrics.streaming_true_positives(predictions, labels),
    'TN': slim.metrics.streaming_true_negatives(predictions, labels),
    'FP': slim.metrics.streaming_false_positives(predictions, labels),
    'FN': slim.metrics.streaming_false_negatives(predictions, labels),
    'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
    'Precision': slim.metrics.streaming_precision(predictions, labels),
    'Recall': slim.metrics.streaming_recall(predictions, labels),
    'Recall_1': slim.metrics.streaming_recall_at_k(logits, labels, 1),
})
Inspecting the metric values:
Accuracy = Recall_1 = 0.88
Precision = 0.93
Recall = 0.929
Accuracy turns out to equal Recall_1, and it does not satisfy (TP + TN) / (TP + FP + TN + FN).
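The mismatch is easy to reproduce without TensorFlow. The sketch below (plain Python, with made-up labels and predictions) mimics the bool cast that streaming_true_positives and friends perform, and shows why (TP + TN) / (TP + FP + TN + FN) no longer equals multi-class accuracy:

```python
# Hypothetical multi-class labels/predictions (3 classes: 0, 1, 2).
labels      = [0, 1, 2, 1]
predictions = [0, 2, 2, 1]

# True multi-class accuracy: fraction of exact matches.
accuracy = sum(l == p for l, p in zip(labels, predictions)) / len(labels)

# What the streaming counters do: cast both tensors to bool first,
# so every non-zero class collapses into a single "positive" class.
lb = [bool(l) for l in labels]       # [False, True, True, True]
pb = [bool(p) for p in predictions]  # [False, True, True, True]

tp = sum(l and p for l, p in zip(lb, pb))              # 3
tn = sum((not l) and (not p) for l, p in zip(lb, pb))  # 1
fp = sum((not l) and p for l, p in zip(lb, pb))        # 0
fn = sum(l and (not p) for l, p in zip(lb, pb))        # 0

binary_accuracy = (tp + tn) / (tp + fp + tn + fn)

print(accuracy)         # 0.75 (the sample at index 1 is misclassified)
print(binary_accuracy)  # 1.0  (the 1-vs-2 error is invisible after the cast)
```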
This problem has come up on Stack Overflow before: https://stackoverflow.com/questions/43408200/tf-slim-computation-of-accuracy
(2) How each metric is computed in the code
First, let's look at how the four counters TP, TN, FP, and FN are computed.
TP:
def true_positives(labels,
                   predictions,
                   weights=None,
                   metrics_collections=None,
                   updates_collections=None,
                   name=None):
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.true_positives is not '
                       'supported when eager execution is enabled.')
  with variable_scope.variable_scope(name, 'true_positives',
                                     (predictions, labels, weights)):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)
    # Labels and predictions are cast to bool: a value of 0 becomes False
    # (negative), and any non-zero value becomes True (positive).
    is_true_positive = math_ops.logical_and(
        math_ops.equal(labels, True), math_ops.equal(predictions, True))
    # Counts the samples where label and prediction are both True.
    return _count_condition(is_true_positive, weights, metrics_collections,
                            updates_collections)
FP:
def false_positives(labels,
                    predictions,
                    weights=None,
                    metrics_collections=None,
                    updates_collections=None,
                    name=None):
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.false_positives is not supported when '
                       'eager execution is enabled.')
  with variable_scope.variable_scope(name, 'false_positives',
                                     (predictions, labels, weights)):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.cast(predictions, dtype=dtypes.bool),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)
    is_false_positive = math_ops.logical_and(
        math_ops.equal(labels, False), math_ops.equal(predictions, True))
    # Counts the samples where the label is False but the prediction is True.
    return _count_condition(is_false_positive, weights, metrics_collections,
                            updates_collections)
The other two counters (TN and FN) are computed analogously, so I won't list them. The source makes it clear that all four counts are defined only for binary classification, where the labels are 0 and 1; for a multi-class model they are not applicable.
It follows that the precision and recall derived from these four counts do not apply to multi-class models either.
Precision = TP / (TP + FP)
Recall = TP / (TP + FN)
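Since the built-in counters are binary-only, multi-class precision and recall have to be computed one class at a time (one-vs-rest). A minimal sketch in plain Python with made-up data — this is my own helper, not part of slim:

```python
# Hypothetical labels/predictions for a 3-class problem.
labels      = [0, 1, 2, 1, 2, 0]
predictions = [0, 2, 2, 1, 2, 1]

classes = sorted(set(labels) | set(predictions))
per_class = {}
for c in classes:
    # One-vs-rest counts: treat class c as positive, everything else as negative.
    tp = sum(p == c and l == c for l, p in zip(labels, predictions))
    fp = sum(p == c and l != c for l, p in zip(labels, predictions))
    fn = sum(p != c and l == c for l, p in zip(labels, predictions))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall    = tp / (tp + fn) if tp + fn else 0.0
    per_class[c] = (precision, recall)

# Macro averages: unweighted mean over classes.
macro_p = sum(p for p, _ in per_class.values()) / len(classes)
macro_r = sum(r for _, r in per_class.values()) / len(classes)
```

Averaging the per-class counts instead of the per-class ratios (micro averaging) would give yet another number; which one you want depends on how much the minority classes should count.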
Now let's look at Accuracy:
def accuracy(labels,
             predictions,
             weights=None,
             metrics_collections=None,
             updates_collections=None,
             name=None):
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.accuracy is not supported when eager '
                       'execution is enabled.')
  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  if labels.dtype != predictions.dtype:
    predictions = math_ops.cast(predictions, labels.dtype)
  is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
  # 1.0 where the prediction equals the label, 0.0 elsewhere.
  return mean(is_correct, weights, metrics_collections, updates_collections,
              name or 'accuracy')
  # The running mean of is_correct is the overall fraction of correct
  # predictions, i.e. top-1 accuracy -- which is why it matches Recall_1.
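A quick sanity check (plain Python, made-up imbalanced data) of what mean(is_correct) actually returns: the overall fraction of exact matches — the same number as top-1 recall — which on imbalanced data differs from the unweighted mean of the per-class recalls:

```python
labels      = [0, 0, 0, 1]   # imbalanced: three of class 0, one of class 1
predictions = [0, 0, 0, 0]   # the single class-1 sample is misclassified

# What streaming_accuracy computes: the mean of the is_correct indicator.
is_correct = [float(l == p) for l, p in zip(labels, predictions)]
accuracy = sum(is_correct) / len(is_correct)        # 0.75

# Unweighted (macro) mean of per-class recalls: a different number.
classes = sorted(set(labels))
recalls = []
for c in classes:
    hits  = sum(l == c and p == c for l, p in zip(labels, predictions))
    total = sum(l == c for l in labels)
    recalls.append(hits / total)
macro_recall = sum(recalls) / len(recalls)          # (1.0 + 0.0) / 2 = 0.5

print(accuracy, macro_recall)
```

So the built-in Accuracy is a perfectly good top-1 accuracy; it just tells you nothing about per-class behavior, which is what the per-class counters above are for.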