sklearn(五)计算acc:使用metrics.accuracy_score()计算分类的准确率

1.acc计算原理

sklearn中accuracy_score函数计算了准确率。

在二分类或者多分类中,预测得到的label,跟真实label比较,计算准确率。

在multilabel(多标签问题)分类中,该函数会返回子集的准确率。如果对于一个样本来说,必须严格匹配真实数据集中的label,整个集合的预测标签返回1.0;否则返回0.0.

2.acc的不适用场景:

正负样本不平衡的情况下,准确率这个评价指标有很大的缺陷。比如在互联网广告里面,点击的数量是很少的,一般只有千分之几,如果用acc,即使全部预测成负类(不点击)acc也有 99% 以上,没有意义。因此,单纯靠准确率来评价一个算法模型是远远不够科学全面的。在类别不平衡没那么太严重时,该指标具有一定的参考意义。

3.metrics.accuracy_score()的使用方法

不管是二分类还是多分类,还是多标签问题,计算公式都为:

这里写图片描述

只是在多标签问题中,TP、TN要求更加严格,必须严格匹配真实数据集中的label。

sklearn.metrics.accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None)

输入参数:

y_true:真是标签。二分类和多分类情况下是一列,多标签情况下是标签的索引。

y_pred:预测标签。二分类和多分类情况下是一列,多标签情况下是标签的索引。

normalize:bool, optional (default=True),如果是false,正确分类的样本的数目(int);如果为true,返回正确分类的样本的比例,必须严格匹配真实数据集中的label,才为1,否则为0。

sample_weight:array-like of shape (n_samples,), default=None。Sample weights.

输出:

如果normalize == True,返回正确分类的样本的比例,否则返回正确分类的样本的数目(int)。

4.例子 

举一个多标签的例子,这里假设有21个标签。

数据格式:预测label是有阈值硬阶段得来,比如当预测得分大于0.5,则在这个索引下的label为1,否则为0。

idlabelpred_labelpred_label_scorepred_scores(模型输出的scores)
1559592780[3][3][0.9060243964195251][0.03569700941443443, 0.025016790255904198, 0.010681516490876675, 0.9060243964195251, 0.03405195102095604, 0.01652703806757927, 0.01057326141744852, 0.015285834670066833, 0.03219904750585556, 0.01710071600973606, 0.015052232891321182, 0.012746844440698624, 0.009399563074111938, 0.012753037735819817, 0.008887830190360546, 0.011201461777091026, 0.013154321350157261, 0.010007181204855442, 0.015232570469379425, 0.011832496151328087, 0.014289622195065022]
1559950270[3][][][0.0441354475915432, 0.07238972187042236, 0.011645170859992504, 0.007589259184896946, 0.25604453682899475, 0.08702245354652405, 0.27572867274284363, 0.00486581027507782, 0.01071715448051691, 0.010638655163347721, 0.005942077841609716, 0.03388604149222374, 0.003174690529704094, 0.006336248945444822, 0.007447054609656334, 0.004069846123456955, 0.06864038109779358, 0.003221432212740183, 0.010166178457438946, 0.014550245366990566, 0.018491217866539955]
1559394894[3][3][0.6821054816246033][0.2968560457229614, 0.0307493656873703, 0.005526685621589422, 0.6821054816246033, 0.019207751378417015, 0.011433916166424751, 0.00833720900118351, 0.011756493709981441, 0.028093582019209862, 0.008476401679217815, 0.00896463356912136, 0.007736032363027334, 0.006790427025407553, 0.009148293174803257, 0.006993972696363926, 0.006845239549875259, 0.008285323157906532, 0.005908709950745106, 0.009022236801683903, 0.008929350413382053, 0.019131703302264214]
1559782048[3][3][0.9018600583076477][0.04472490772604942, 0.0243248138576746, 0.011095968075096607, 0.9018600583076477, 0.02759535051882267, 0.01639750227332115, 0.010229885578155518, 0.01442675106227398, 0.03185756132006645, 0.01614650897681713, 0.014211165718734264, 0.011741148307919502, 0.00937943160533905, 0.013027109205722809, 0.008298314176499844, 0.010878310538828373, 0.012541105970740318, 0.009680655784904957, 0.014786235056817532, 0.01098882406949997, 0.014351315796375275]
1560480983[3][6][0.5473132729530334][0.07873011380434036, 0.02117929421365261, 0.00462101586163044, 0.007679674308747053, 0.006423152983188629, 0.003737745573744178, 0.5473132729530334, 0.010648651979863644, 0.2306162267923355, 0.033958908170461655, 0.009718521498143673, 0.03945154696702957, 0.0667884573340416, 0.010746568441390991, 0.008459050208330154, 0.012853718362748623, 0.006122407037764788, 0.005631749518215656, 0.006334631238132715, 0.01488021295517683, 0.08340618759393692]

demo:

目的:计算标签的整体acc、precision、recall。

如果想计算某一个类别的precision和recall,则在评价函数中加上这个参数:pos_label = [4],这里的4表示索引的第4列。

def calculate_acc_multi_label(read_path, sheet_name):
    workbook = xlrd.open_workbook(read_path)  # 打开工作簿
    sheets = workbook.sheet_names()  # 获取工作簿中的所有表格
    worksheet = workbook.sheet_by_name(sheets[0])  # 获取工作簿中所有表格中的的第一个表格
    print(worksheet.nrows)
    print(worksheet.ncols)
    true_label = []
    pred_label = []
    for i in range(1, 501):
        label_str = worksheet.cell_value(i, 1)
        label = [0 for x in range(0, 21)]
        label_str = label_str[1:-1]
        label_list = label_str.split(',')
        for j in label_list:
            label[int(j)] = 1
        true_label.append(label)

        pred_list = worksheet.cell_value(i, 2)
        pred_lab = [0 for x in range(0, 21)]
        # print('--length of pred: ', len(pred_list))
        pred_list = pred_list[1:-1]
        print('---index: {0}  pred_list {1}: '.format(i, pred_list))
        if pred_list != '':
            pred_list = pred_list.split(',')
            for g in pred_list:
                pred_lab[int(g)] = 1
        pred_label.append(pred_lab)
    acc = metrics.accuracy_score(true_label, pred_label)
    print('--acc:', acc)
    # acc_list = hamming_score(true_label, pred_label)
    # hamming = np.mean(acc_list)
    # print('--hamming:', hamming)
    precision = metrics.precision_score(true_label, pred_label,  average='micro')
    print('--precision:', precision)
    recall = metrics.recall_score(true_label, pred_label, average='micro')
    print('--recall:', recall)
    f1 = metrics.f1_score(np.array(true_label), np.array(pred_label), average='micro')
    print('--f1:', f1)

    mcm = metrics.multilabel_confusion_matrix(true_label, pred_label)
    tn = mcm[:, 0, 0]
    tp = mcm[:, 1, 1]
    fn = mcm[:, 1, 0]
    fp = mcm[:, 0, 1]
    print('tp: {0}  fn: {1}  fp: {2}'.format(tp, fn, fp))

    sum_tp = sum(tp)
    sum_fn = sum(fn)
    sum_fp = sum(fp)
    print('sum_tp: {0}  sum_fn: {1}  sum_fp: {2}'.format(sum_tp, sum_fn, sum_fp))
    recall_list = tp / (tp + fn)
    print('--recall_list', recall_list)
    precision_list = tp / (tp + fp)
    print('--precision_list', precision_list)
    print('--precision_list length', len(precision_list))
    print('---mcm :', mcm)
if __name__ == '__main__':

    save_path = './multi_label_all_0.5_2.xlsx'
    sheet_name = 'predict'
    calculate_acc_multi_label(save_path, sheet_name)

这里打印了多标签的混淆矩阵,用来验证acc、precision、recall是怎么计算得到的,运行后,返回结果如下:

--acc: 0.704
--precision: 0.7960396039603961
--recall: 0.7052631578947368
--f1: 0.747906976744186
tp: [  2   2   0 393   4   0   0   1   0   0   0   0   0   0   0   0   0   0
   0   0   0]  fn: [  0  14   2 107   6   2   0  13  16   0   5   1   0   0   1   0   1   0
   0   0   0]  fp: [45 18  4  0 13  0 13  1  2  5  0  1  0  0  0  1  0  0  0  0  0]
sum_tp: 402  sum_fn: 168  sum_fp: 103
--recall_list [1.         0.125      0.         0.786      0.4        0.
        nan 0.07142857 0.                nan 0.         0.
        nan        nan 0.                nan 0.                nan
        nan        nan        nan]
--precision_list [0.04255319 0.1        0.         1.         0.23529412        nan
 0.         0.5        0.         0.                nan 0.
        nan        nan        nan 0.                nan        nan
        nan        nan        nan]
--precision_list length 21
---mcm : [[[453  45]
  [  0   2]]
 [[466  18]
  [ 14   2]]
 [[494   4]
  [  2   0]]
 [[  0   0]
  [107 393]]
 [[477  13]
  [  6   4]]
 [[498   0]
  [  2   0]]
 [[487  13]
  [  0   0]]
 [[485   1]
  [ 13   1]]
 [[482   2]
  [ 16   0]]
 [[495   5]
  [  0   0]]
 [[495   0]
  [  5   0]]
 [[498   1]
  [  1   0]]
 [[500   0]
  [  0   0]]
 [[500   0]
  [  0   0]]
 [[499   0]
  [  1   0]]
 [[499   1]
  [  0   0]]
 [[499   0]
  [  1   0]]
 [[500   0]
  [  0   0]]
 [[500   0]
  [  0   0]]
 [[500   0]
  [  0   0]]
 [[500   0]
  [  0   0]]]

可以用过混淆矩阵计算acc 、precision、recall等指标。

 

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值