Model Evaluation
Symbol | Meaning |
---|---|
$y$ | ground-truth labels (a vector) |
$\hat y$ | predicted labels (a vector) |
Metrics for Classification
Accuracy
Number of correct predictions / total number of samples.
E.g.:
$y = (0,0,1,0,1)$
$\hat y = (1,0,1,0,1)$
$$\begin{aligned} Accuracy &= sum(y == \hat y)/len(y) \\ &= sum((0,0,1,0,1) == (1,0,1,0,1))/5 \\ &= 4/5 \end{aligned}$$
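The computation above can be sketched in plain Python (the helper name `accuracy` is mine, not from any library):

```python
def accuracy(y, y_hat):
    """Fraction of positions where the prediction equals the ground truth."""
    assert len(y) == len(y_hat)
    correct = sum(1 for yi, pi in zip(y, y_hat) if yi == pi)
    return correct / len(y)

y = (0, 0, 1, 0, 1)
y_hat = (1, 0, 1, 0, 1)
print(accuracy(y, y_hat))  # 0.8, i.e. 4/5
```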
Precision
When the positive and negative classes are imbalanced, accuracy can be misleading about how well the model really performs, so other metrics are needed to weigh the trade-offs.
For a given class: number of predictions of that class that match the ground truth / number of times that class appears in the predictions.
E.g.:
$y = (0,0,1,0,1)$
$\hat y = (1,0,1,0,1)$
Precision for the positive class (class 1):
$$\begin{aligned} Precision &= sum(y==1\ and\ \hat y==1)/sum(\hat y==1) \\ &= sum((0,0,1,0,1)==1\ and\ (1,0,1,0,1)==1)/sum((1,0,1,0,1)==1) \\ &= 2/3 \end{aligned}$$
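This can be sketched directly from the definition (the function name `precision` and the `cls` parameter are my own):

```python
def precision(y, y_hat, cls=1):
    """Correct predictions of `cls` / all predictions of `cls`."""
    true_pos = sum(1 for yi, pi in zip(y, y_hat) if yi == cls and pi == cls)
    predicted = sum(1 for pi in y_hat if pi == cls)
    return true_pos / predicted if predicted else 0.0

print(precision((0, 0, 1, 0, 1), (1, 0, 1, 0, 1)))  # 2/3
```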
Recall
For a given class: number of predictions of that class that match the ground truth / number of times that class appears in the ground truth.
E.g.:
$y = (0,0,1,0,1)$
$\hat y = (1,0,1,0,1)$
Recall for the positive class (class 1):
$$\begin{aligned} Recall &= sum(y==1\ and\ \hat y==1)/sum(y==1) \\ &= sum((0,0,1,0,1)==1\ and\ (1,0,1,0,1)==1)/sum((0,0,1,0,1)==1) \\ &= 2/2 \end{aligned}$$
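Recall mirrors precision, only the denominator changes to the ground-truth counts (again, `recall` is my own helper name):

```python
def recall(y, y_hat, cls=1):
    """Correct predictions of `cls` / occurrences of `cls` in the ground truth."""
    true_pos = sum(1 for yi, pi in zip(y, y_hat) if yi == cls and pi == cls)
    actual = sum(1 for yi in y if yi == cls)
    return true_pos / actual if actual else 0.0

print(recall((0, 0, 1, 0, 1), (1, 0, 1, 0, 1)))  # 2/2 = 1.0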
F1
Balances precision and recall (their harmonic mean):
$F1 = 2*P*R/(P+R)$
E.g.:
Computing F1 for the example above:
$$\begin{aligned} F1 &= 2*P*R/(P+R) \\ &= (2*2/3*2/2)/(2/3+2/2) \\ &= 4/5 \end{aligned}$$
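Plugging the precision and recall from the running example into the formula (a minimal sketch; the guard against $P+R=0$ is my own addition):

```python
def f1(p, r):
    """Harmonic mean of precision and recall."""
    return 2 * p * r / (p + r) if (p + r) else 0.0

print(f1(2/3, 2/2))  # 0.8, i.e. 4/5
```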
Metrics for Question Answering
Exact Match (EM)
Exact Match measures whether the predicted answer is exactly identical to the reference answer. It is a common evaluation criterion for question-answering systems: the percentage of predictions that exactly match one of the ground-truth answers.
EM is one of the main metrics used for SQuAD.
Definition
Name | Symbol |
---|---|
Predicted answer string | $S_{pred}$ |
Reference answer string | $S_{ref}$ |

$$EM = \begin{cases} 1, & \text{if } S_{pred} = S_{ref} \\ 0, & \text{if } S_{pred} \neq S_{ref} \end{cases}$$
EM also has drawbacks: some answers differ only in articles such as "the" or "an" (e.g. the prediction is "apple" but the reference is "an apple"), or in surface form such as the digit "1" versus the word "one", even though the answer is actually correct.
F1
Measures the similarity between the predicted answer and the reference answer; it lies in $[0,1]$.

Name | Symbol |
---|---|
Predicted answer string | $S_{pred}$ |
Reference answer string | $S_{ref}$ |
Predicted answer token set | $W_{pred}$ |
Reference answer token set | $W_{ref}$ |
Number of tokens in set $W$ | $\lvert W \rvert$ |

Definition
Take the intersection of the two token sets:
$$W_{match} = W_{pred} \cap W_{ref}$$
$$P = \frac{|W_{match}|}{|W_{pred}|}$$
$$R = \frac{|W_{match}|}{|W_{ref}|}$$
$$F1 = 2*P*R/(P+R)$$
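A small worked example of these formulas (the token sets here are hypothetical; plain Python sets are used for clarity, whereas SQuAD's official script counts repeated tokens with multisets, as shown in the code below):

```python
w_pred = {"quick", "fox"}             # W_pred: predicted answer tokens
w_ref = {"quick", "brown", "fox"}     # W_ref: reference answer tokens
w_match = w_pred & w_ref              # W_match = W_pred ∩ W_ref

p = len(w_match) / len(w_pred)        # |W_match| / |W_pred| = 2/2 = 1.0
r = len(w_match) / len(w_ref)         # |W_match| / |W_ref|  = 2/3
f1 = 2 * p * r / (p + r)              # 0.8
print(p, r, f1)
```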
Code
Import the required libraries
```python
from collections import Counter
import string
import re
```
Before matching, first normalize the answers:
- lowercase all words
- remove punctuation
- remove articles (a/an/the)
- tokenize on whitespace
```python
def normalize_answer(s):
    """Lowercase, strip punctuation and articles, and collapse whitespace."""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
```
Compute F1
```python
def f1_score(prediction, ground_truth):
    # First normalize prediction and ground truth (using the function above)
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    # Count the tokens the two answers share (multiset intersection)
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    # Precision: shared tokens / predicted tokens
    precision = 1.0 * num_same / len(prediction_tokens)
    # Recall: shared tokens / reference tokens
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1
```
Compute EM
```python
def exact_match_score(prediction, ground_truth):
    return (normalize_answer(prediction) == normalize_answer(ground_truth))
```