1. mxnet.metric.check_lable_shapes(labels, preds, wrap=False, shape=False)
labels: data's labels, ndarray
preds: predicted values, ndarray
wrap : boolean, if True, 如果 labels/preds 是 single NDarray的话就把它们打包成 list.
shape : boolean, if True的话,就check labels and preds's shape, 否则就仅仅check它们的长度。
2. mxnet.metric.EvalMetric(name, output_names=None, label_names=None, **kwargs)
这是一个类!并且是所有评价度量的基类。这个类中提供了通常的评价度量的接口, 不应该直接使用这个类,而应该创建一个子类来继承它。
这里的参数,name是要创建的metric实例的名字,output_names是predictions的名字, 应该在更新时使用
label_names是labels's 名字,应该在更新时使用。
然后这个类中有很多方法
2.1 ___init__(self, name, output_names=None, label_names=None, **kwargs)
这个是用来初始化的,不用说了。
2.2 __str__(self)
这个里面写的是:
def __str__(self):
return "EvalMetric: {}".format(dict(self.get_name_value()))
所以我们去看一下get_name_value()
2.3
def get_name_value(self)
name, value = self.get()
if not isinstance(name, list):
name = [name]
if not isinstance(value, list):
value = [value]
return list(zip(name, value))
所以我们又要先看一下get函数
2. 4
def get(self):
if self.num_inst == 0:
return (self.name, float('nan'))
else:
return (self.name, self.sum_metric / self.num_inst)
然后接下来看看
2.5 接下来的这两个函数都用到了update
def get_config(self):
这个直接就是获得更新后的,metric, name, output_names, label_name
和
2.6
def update_dict(self, label, pred):
这个就是更新了,pred和label分别根据output_names和label_names
然后再调用
update(self, labels, preds)
2.7
def update(self, labels, preds):
注意这个是一定要实现的,不然直接会报错。
2.8
def create(metric, *args,**args),这个在上一个博文中已经提到了。