1. How accuracy (acc) is computed
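The accuracy_score function in sklearn computes accuracy.
In binary or multiclass classification, each sample's predicted label is compared with its true label, and the accuracy is the fraction of samples whose labels match.
In multilabel classification, the function returns subset accuracy. A sample scores 1.0 only if its predicted label set exactly matches its true label set; otherwise it scores 0.0. The result is the average of these per-sample scores.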
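A short sketch contrasting the two cases, using small hypothetical arrays. In the multiclass case each label is compared on its own. In the multilabel case each row of the label indicator matrix is one sample, and a row counts only if every column matches.

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Multiclass (hypothetical data): each prediction is compared with its true label
y_true = [0, 1, 2, 3]
y_pred = [0, 2, 1, 3]
print(accuracy_score(y_true, y_pred))  # 2 of 4 correct -> 0.5

# Multilabel (hypothetical data): label indicator matrices, one row per sample
Y_true = np.array([[0, 1, 1],
                   [1, 0, 0]])
Y_pred = np.array([[0, 1, 0],   # misses one label, so this sample scores 0
                   [1, 0, 0]])  # exact match, so this sample scores 1
print(accuracy_score(Y_true, Y_pred))  # subset accuracy -> 0.5
```

Note that the first multilabel sample gets no partial credit for the one label it got right.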
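2. When acc is not suitable:
When positive and negative samples are imbalanced, accuracy has a serious flaw as an evaluation metric. In online advertising, for example, clicks are rare, usually only a few per thousand impressions. A model that predicts every sample as negative (no click) still reaches an accuracy above 99%, yet it never finds a single click, so the number is meaningless. Accuracy alone is therefore far from a sufficient or rigorous way to evaluate a model. When the class imbalance is less severe, it still has some reference value.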
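The advertising scenario can be reproduced with a small sketch. The data here is hypothetical, assuming 1000 impressions with 3 clicks and a "model" that always predicts "no click". Recall on the positive class exposes what accuracy hides.

```python
import numpy as np
from sklearn.metrics import accuracy_score, recall_score

# Hypothetical data: 1000 impressions, 3 of them clicked (positive class = 1)
y_true = np.zeros(1000, dtype=int)
y_true[:3] = 1

# A degenerate "model" that predicts "no click" for every impression
y_pred = np.zeros(1000, dtype=int)

acc = accuracy_score(y_true, y_pred)  # 997 / 1000
rec = recall_score(y_true, y_pred)    # 0 / 3: no click is ever found
print(acc, rec)
```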
3. Using metrics.accuracy_score()
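Whether the problem is binary, multiclass, or multilabel, accuracy is the fraction of correctly predicted samples:
$$\text{accuracy}(y, \hat{y}) = \frac{1}{n_\text{samples}} \sum_{i=0}^{n_\text{samples}-1} 1(\hat{y}_i = y_i)$$
In the binary case this is the familiar
$$\text{accuracy} = \frac{TP + TN}{TP + TN + FP + FN}$$
The only difference in multilabel problems is the stricter notion of "correct": $\hat{y}_i = y_i$ holds only if the predicted labels of sample $i$ exactly match its true labels.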
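A minimal check that both forms of the formula agree with accuracy_score on a small hypothetical binary example. confusion_matrix(...).ravel() returns the counts in the order tn, fp, fn, tp for binary labels 0 and 1.

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix

# Hypothetical binary labels
y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])
y_pred = np.array([1, 0, 0, 1, 0, 1, 1, 0])

# Binary form: (TP + TN) / (TP + TN + FP + FN)
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
acc_formula = (tp + tn) / (tp + tn + fp + fn)

# General form: mean of the per-sample indicator 1(y_pred_i == y_true_i)
acc_mean = np.mean(y_true == y_pred)

print(acc_formula, acc_mean, accuracy_score(y_true, y_pred))
```

Here TP = 3, TN = 3, FP = 1 and FN = 1, so all three values are 6 / 8.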
sklearn.metrics.accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None)
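Parameters:
y_true: the true labels. For binary and multiclass problems this is a 1-D array with one label per sample; for multilabel problems it is a label indicator matrix of shape (n_samples, n_labels), where 1 marks each label that applies to the sample.
y_pred: the predicted labels, in the same format as y_true.
normalize: bool, default=True. If False, return the number of correctly classified samples; if True, return the fraction of correctly classified samples. In the multilabel case a sample counts as correct only if its whole label set matches exactly.
sample_weight: array-like of shape (n_samples,), default=None. Sample weights.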
Returns:
If normalize == True, the fraction of correctly classified samples; otherwise the number of correctly classified samples.
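A small sketch of normalize and sample_weight on hypothetical labels. With weights, the "fraction" becomes the weighted share of correct samples.

```python
from sklearn.metrics import accuracy_score

# Hypothetical multiclass labels: samples 0 and 3 are correct
y_true = [0, 1, 2, 3]
y_pred = [0, 2, 1, 3]

frac = accuracy_score(y_true, y_pred)                    # fraction of correct samples
count = accuracy_score(y_true, y_pred, normalize=False)  # number of correct samples
# Weighted fraction: (1 + 3) / (1 + 1 + 1 + 3)
weighted = accuracy_score(y_true, y_pred, sample_weight=[1, 1, 1, 3])
print(frac, count, weighted)
```

Depending on the sklearn version, the count may be returned as an int or as a float.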
4. Example
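Here is a multilabel example, assuming 21 labels.
Data format: the predicted labels are obtained by hard thresholding the scores. For example, if the predicted score at an index is greater than 0.5, the label at that index is 1; otherwise it is 0.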
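The thresholding rule above can be sketched as a small helper. The name scores_to_labels is hypothetical and not part of the original script. It returns both the 0/1 indicator vector and the list of label indices, which is the form shown in the pred_label column below (an empty list when no score exceeds the threshold).

```python
import numpy as np


def scores_to_labels(scores, threshold=0.5):
    """Hypothetical helper: hard-threshold one sample's score vector."""
    scores = np.asarray(scores)
    indicator = (scores > threshold).astype(int)  # 1 where score > threshold
    indices = np.flatnonzero(indicator).tolist()  # e.g. [3], or [] if nothing passes
    return indicator, indices


# Hypothetical 5-label score vectors, for illustration only
indicator, idx = scores_to_labels([0.04, 0.03, 0.01, 0.91, 0.03])
print(idx)  # [3]
_, idx_empty = scores_to_labels([0.04, 0.07, 0.28, 0.26, 0.09])
print(idx_empty)  # []
```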
id | label | pred_label | pred_label_score | pred_scores (scores output by the model) |
1559592780 | [3] | [3] | [0.9060243964195251] | [0.03569700941443443, 0.025016790255904198, 0.010681516490876675, 0.9060243964195251, 0.03405195102095604, 0.01652703806757927, 0.01057326141744852, 0.015285834670066833, 0.03219904750585556, 0.01710071600973606, 0.015052232891321182, 0.012746844440698624, 0.009399563074111938, 0.012753037735819817, 0.008887830190360546, 0.011201461777091026, 0.013154321350157261, 0.010007181204855442, 0.015232570469379425, 0.011832496151328087, 0.014289622195065022] |
1559950270 | [3] | [] | [] | [0.0441354475915432, 0.07238972187042236, 0.011645170859992504, 0.007589259184896946, 0.25604453682899475, 0.08702245354652405, 0.27572867274284363, 0.00486581027507782, 0.01071715448051691, 0.010638655163347721, 0.005942077841609716, 0.03388604149222374, 0.003174690529704094, 0.006336248945444822, 0.007447054609656334, 0.004069846123456955, 0.06864038109779358, 0.003221432212740183, 0.010166178457438946, 0.014550245366990566, 0.018491217866539955] |
1559394894 | [3] | [3] | [0.6821054816246033] | [0.2968560457229614, 0.0307493656873703, 0.005526685621589422, 0.6821054816246033, 0.019207751378417015, 0.011433916166424751, 0.00833720900118351, 0.011756493709981441, 0.028093582019209862, 0.008476401679217815, 0.00896463356912136, 0.007736032363027334, 0.006790427025407553, 0.009148293174803257, 0.006993972696363926, 0.006845239549875259, 0.008285323157906532, 0.005908709950745106, 0.009022236801683903, 0.008929350413382053, 0.019131703302264214] |
1559782048 | [3] | [3] | [0.9018600583076477] | [0.04472490772604942, 0.0243248138576746, 0.011095968075096607, 0.9018600583076477, 0.02759535051882267, 0.01639750227332115, 0.010229885578155518, 0.01442675106227398, 0.03185756132006645, 0.01614650897681713, 0.014211165718734264, 0.011741148307919502, 0.00937943160533905, 0.013027109205722809, 0.008298314176499844, 0.010878310538828373, 0.012541105970740318, 0.009680655784904957, 0.014786235056817532, 0.01098882406949997, 0.014351315796375275] |
1560480983 | [3] | [6] | [0.5473132729530334] | [0.07873011380434036, 0.02117929421365261, 0.00462101586163044, 0.007679674308747053, 0.006423152983188629, 0.003737745573744178, 0.5473132729530334, 0.010648651979863644, 0.2306162267923355, 0.033958908170461655, 0.009718521498143673, 0.03945154696702957, 0.0667884573340416, 0.010746568441390991, 0.008459050208330154, 0.012853718362748623, 0.006122407037764788, 0.005631749518215656, 0.006334631238132715, 0.01488021295517683, 0.08340618759393692] |
demo:
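Goal: compute the overall acc, precision, and recall over all labels.
To get precision and recall for a single label, pos_label does not help: it only applies when average='binary' and is ignored for multilabel input. Instead, pass labels=[4] to the metric function, where 4 is the column index of that label (counting from 0), or use average=None and read element 4 of the returned per-label array.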
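A sketch of both ways to get the scores for label 4, on small hypothetical 5-label indicator matrices. zero_division=0 silences the warnings for labels that never occur in either matrix.

```python
import numpy as np
from sklearn.metrics import precision_score, recall_score

# Hypothetical 4-sample, 5-label indicator matrices
Y_true = np.array([[0, 0, 0, 1, 1],
                   [0, 0, 0, 1, 0],
                   [0, 1, 0, 0, 1],
                   [0, 0, 0, 1, 1]])
Y_pred = np.array([[0, 0, 0, 1, 1],
                   [0, 0, 0, 1, 1],
                   [0, 1, 0, 0, 0],
                   [0, 0, 0, 0, 0]])

# Option 1: average=None returns one score per label; pick column 4
per_label_p = precision_score(Y_true, Y_pred, average=None, zero_division=0)
per_label_r = recall_score(Y_true, Y_pred, average=None, zero_division=0)
print(per_label_p[4], per_label_r[4])

# Option 2: restrict the computation to column 4 via `labels`
p4 = precision_score(Y_true, Y_pred, labels=[4], average='micro')
r4 = recall_score(Y_true, Y_pred, labels=[4], average='micro')
print(p4, r4)
```

For column 4 there is 1 TP, 1 FP and 2 FN, so precision is 1/2 and recall is 1/3 in both options.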
import numpy as np
import xlrd  # xlrd >= 2.0 reads only .xls; reading this .xlsx file needs xlrd < 2.0
from sklearn import metrics

NUM_LABELS = 21  # number of labels in this example


def to_indicator(cell_str):
    """Turn a cell such as '[3]', '[3, 6]' or '[]' into a 0/1 vector of length NUM_LABELS."""
    vec = [0] * NUM_LABELS
    inner = cell_str.strip()[1:-1]  # drop the surrounding brackets
    if inner.strip() != '':  # '[]' means no label passed the threshold
        for idx in inner.split(','):
            vec[int(idx)] = 1
    return vec


def calculate_acc_multi_label(read_path, sheet_name):
    workbook = xlrd.open_workbook(read_path)  # open the workbook
    worksheet = workbook.sheet_by_name(sheet_name)  # use the sheet that was passed in
    print(worksheet.nrows)
    print(worksheet.ncols)
    true_label = []
    pred_label = []
    # Row 0 is the header; column 1 holds label, column 2 holds pred_label.
    # (The original hard-coded range(1, 501), i.e. 500 samples.)
    for i in range(1, worksheet.nrows):
        true_label.append(to_indicator(worksheet.cell_value(i, 1)))
        pred_label.append(to_indicator(worksheet.cell_value(i, 2)))
    true_label = np.array(true_label)
    pred_label = np.array(pred_label)

    # Subset accuracy: a sample counts only if all NUM_LABELS labels match
    acc = metrics.accuracy_score(true_label, pred_label)
    print('--acc:', acc)
    # acc_list = hamming_score(true_label, pred_label)
    # hamming = np.mean(acc_list)
    # print('--hamming:', hamming)

    # Micro averaging pools TP, FP and FN over all labels before dividing
    precision = metrics.precision_score(true_label, pred_label, average='micro')
    print('--precision:', precision)
    recall = metrics.recall_score(true_label, pred_label, average='micro')
    print('--recall:', recall)
    f1 = metrics.f1_score(true_label, pred_label, average='micro')
    print('--f1:', f1)

    # One 2x2 matrix per label, laid out as [[tn, fp], [fn, tp]]
    mcm = metrics.multilabel_confusion_matrix(true_label, pred_label)
    tn = mcm[:, 0, 0]
    tp = mcm[:, 1, 1]
    fn = mcm[:, 1, 0]
    fp = mcm[:, 0, 1]
    print('tp: {0} fn: {1} fp: {2}'.format(tp, fn, fp))
    sum_tp = tp.sum()
    sum_fn = fn.sum()
    sum_fp = fp.sum()
    print('sum_tp: {0} sum_fn: {1} sum_fp: {2}'.format(sum_tp, sum_fn, sum_fp))

    # Per-label ratios; labels with a zero denominator come out as nan
    with np.errstate(divide='ignore', invalid='ignore'):
        recall_list = tp / (tp + fn)
        precision_list = tp / (tp + fp)
    print('--recall_list', recall_list)
    print('--precision_list', precision_list)
    print('--precision_list length', len(precision_list))
    print('---mcm :', mcm)


if __name__ == '__main__':
    save_path = './multi_label_all_0.5_2.xlsx'
    sheet_name = 'predict'
    calculate_acc_multi_label(save_path, sheet_name)
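The commented-out lines in the function call a hamming_score helper that the post does not define. A minimal sketch of one common definition follows. It is an assumption, not the author's code: for each sample it takes |true ∩ pred| / |true ∪ pred| and returns the list of per-sample scores, so that np.mean gives the average. Averaged this way it equals sklearn's jaccard_score(..., average='samples'). sklearn's hamming_loss is a different quantity: the fraction of individual label cells that are wrong.

```python
import numpy as np
from sklearn.metrics import hamming_loss, jaccard_score


def hamming_score(y_true, y_pred):
    """Hypothetical helper: per-sample |T ∩ P| / |T ∪ P|, returned as a list."""
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=bool)
    scores = []
    for t, p in zip(y_true, y_pred):
        union = np.logical_or(t, p).sum()
        inter = np.logical_and(t, p).sum()
        # Both label sets empty: treated as a perfect match (assumption)
        scores.append(1.0 if union == 0 else inter / union)
    return scores


# Hypothetical indicator matrices
Y_true = [[0, 1, 1], [1, 0, 0]]
Y_pred = [[0, 1, 0], [1, 0, 0]]
print(np.mean(hamming_score(Y_true, Y_pred)))            # (1/2 + 1) / 2
print(jaccard_score(Y_true, Y_pred, average='samples'))  # same value
print(hamming_loss(Y_true, Y_pred))                      # 1 wrong cell out of 6
```

Unlike subset accuracy, this score gives partial credit: the first sample contributes 0.5 instead of 0.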
The script prints the multilabel confusion matrices to verify how acc, precision, and recall are computed. Running it gives:
--acc: 0.704
--precision: 0.7960396039603961
--recall: 0.7052631578947368
--f1: 0.747906976744186
tp: [2 2 0 393 4 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
fn: [0 14 2 107 6 2 0 13 16 0 5 1 0 0 1 0 1 0 0 0 0]
fp: [45 18 4 0 13 0 13 1 2 5 0 1 0 0 0 1 0 0 0 0 0]
sum_tp: 402 sum_fn: 168 sum_fp: 103
--recall_list [1. 0.125 0. 0.786 0.4 0.
nan 0.07142857 0. nan 0. 0.
nan nan 0. nan 0. nan
nan nan nan]
--precision_list [0.04255319 0.1 0. 1. 0.23529412 nan
0. 0.5 0. 0. nan 0.
nan nan nan 0. nan nan
nan nan nan]
--precision_list length 21
---mcm : [[[453 45]
[ 0 2]]
[[466 18]
[ 14 2]]
[[494 4]
[ 2 0]]
[[ 0 0]
[107 393]]
[[477 13]
[ 6 4]]
[[498 0]
[ 2 0]]
[[487 13]
[ 0 0]]
[[485 1]
[ 13 1]]
[[482 2]
[ 16 0]]
[[495 5]
[ 0 0]]
[[495 0]
[ 5 0]]
[[498 1]
[ 1 0]]
[[500 0]
[ 0 0]]
[[500 0]
[ 0 0]]
[[499 0]
[ 1 0]]
[[499 1]
[ 0 0]]
[[499 0]
[ 1 0]]
[[500 0]
[ 0 0]]
[[500 0]
[ 0 0]]
[[500 0]
[ 0 0]]
[[500 0]
[ 0 0]]]
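Per-label accuracy, precision, and recall, as well as the micro-averaged precision, recall, and F1, can all be computed from these confusion matrices. The subset accuracy returned by accuracy_score (0.704) cannot, because it depends on which labels occur together in each sample, and the per-label matrices do not record that.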
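A sketch that reproduces the micro-averaged numbers from the per-label counts printed above. The tp, fn and fp arrays and the label-3 matrix are copied from that output. Micro averaging sums the counts over all 21 labels first and only then divides.

```python
import numpy as np

# Per-label counts copied from the printed output (21 labels)
tp = np.array([2, 2, 0, 393, 4, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
fn = np.array([0, 14, 2, 107, 6, 2, 0, 13, 16, 0, 5, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0])
fp = np.array([45, 18, 4, 0, 13, 0, 13, 1, 2, 5, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0])

# Micro averaging: pool the counts over all labels, then take the ratios
TP, FN, FP = tp.sum(), fn.sum(), fp.sum()   # 402, 168, 103
micro_precision = TP / (TP + FP)            # 402 / 505
micro_recall = TP / (TP + FN)               # 402 / 570
micro_f1 = 2 * TP / (2 * TP + FP + FN)      # 804 / 1075
print(micro_precision, micro_recall, micro_f1)

# Per-label accuracy from one 2x2 block laid out as [[tn, fp], [fn, tp]]
mcm_label3 = np.array([[0, 0], [107, 393]])
(tn3, fp3), (fn3, tp3) = mcm_label3
label3_acc = (tn3 + tp3) / mcm_label3.sum()  # 393 / 500
print(label3_acc)
```

The three micro-averaged values match the --precision, --recall and --f1 lines of the output.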