Chapter 10: Evaluating Classification Results

10-1 The Accuracy Trap and the Confusion Matrix

10-2 Precision and Recall

10-3 Implementing the Confusion Matrix, Precision, and Recall

Notebook example

Notebook source code

[1]
import numpy as np
from sklearn import datasets
[2]
digits = datasets.load_digits()
X = digits.data
y = digits.target.copy()

y[digits.target==9] = 1
y[digits.target!=9] = 0
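Binarizing the labels this way (digit 9 as the positive class, everything else negative) deliberately creates a skewed problem: only about 10% of the samples are positive, which is exactly the setting in which raw accuracy is misleading.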
[3]
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y,random_state=666)
[4]
X_train.shape
(1347, 64)
[5]
from sklearn.linear_model import LogisticRegression

log_reg = LogisticRegression()
log_reg.fit(X_train,y_train)
log_reg.score(X_test,y_test)
F:\anaconda\lib\site-packages\sklearn\linear_model\_logistic.py:814: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(

0.9755555555555555
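The ConvergenceWarning means lbfgs hit its iteration cap on the unscaled 0-16 pixel values; the fitted model is still usable for this demo. If you want to silence it, the warning itself points at the two standard remedies; a minimal sketch of both:

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# remedy 1: standardize the features so the solver converges quickly
scaled_log_reg = make_pipeline(StandardScaler(), LogisticRegression())
scaled_log_reg.fit(X_train, y_train)

# remedy 2: give the solver more iterations
log_reg_more_iter = LogisticRegression(max_iter=10000)
log_reg_more_iter.fit(X_train, y_train)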
[6]
y_log_predict = log_reg.predict(X_test)
[7]
def TN(y_true, y_predict):
    assert len(y_true) == len(y_predict)
    # truly negative and predicted negative
    return np.sum((y_true == 0) & (y_predict == 0))

TN(y_test,y_log_predict)
403
[8]
def FP(y_true, y_predict):
    assert len(y_true) == len(y_predict)
    # truly negative but predicted positive
    return np.sum((y_true == 0) & (y_predict == 1))

FP(y_test,y_log_predict)
2
[9]
def FN(y_true, y_predict):
    assert len(y_true) == len(y_predict)
    # truly positive but predicted negative
    return np.sum((y_true == 1) & (y_predict == 0))

FN(y_test,y_log_predict)
9
[10]
def TP(y_true, y_predict):
    assert len(y_true) == len(y_predict)
    # truly positive and predicted positive
    return np.sum((y_true == 1) & (y_predict == 1))

TP(y_test,y_log_predict)
36
[11]
def confusion_matrix(y_true, y_predict):
    return np.array([
        [TN(y_true,y_predict), FP(y_true,y_predict)],
        [FN(y_true,y_predict), TP(y_true,y_predict)]
    ])
confusion_matrix(y_test, y_log_predict)
array([[403,   2],
       [  9,  36]])
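Rows are true classes and columns are predicted classes, so the layout is [[TN, FP], [FN, TP]]: 403 true negatives, 2 false positives, 9 false negatives, and 36 true positives.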
[12]
def precision_score(y_true, y_predict):
    tp = TP(y_true, y_predict)
    fp = FP(y_true, y_predict)
    # TP/FP are NumPy integers, so 0/0 yields nan instead of raising
    # ZeroDivisionError; guard the denominator explicitly
    return tp / (tp + fp) if tp + fp != 0 else 0.0
    
precision_score(y_test, y_log_predict)
0.9473684210526315
[13]
def recall_score(y_true, y_predict):
    tp = TP(y_true, y_predict)
    fn = FN(y_true, y_predict)
    # same NumPy-integer caveat as precision_score
    return tp / (tp + fn) if tp + fn != 0 else 0.0
    
recall_score(y_test, y_log_predict)
0.8
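Read together: when the model predicts a 9 it is right about 95% of the time (36 of 38 positive predictions), but it only finds 36 of the 45 actual 9s in the test set.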
Confusion Matrix, Precision, and Recall in scikit-learn
[14]
from sklearn.metrics import confusion_matrix

confusion_matrix(y_test, y_log_predict)
array([[403,   2],
       [  9,  36]], dtype=int64)
[15]
from sklearn.metrics import precision_score

precision_score(y_test,y_log_predict)
0.9473684210526315
[16]
from sklearn.metrics import recall_score

recall_score(y_test,y_log_predict)
0.8

10-4 F1 Score

10-5 The Precision-Recall Trade-off

Notebook example

Notebook source code

F1 Score
[18]
import numpy as np
import matplotlib.pyplot as plt
[4]
def f1_score(precision, recall):
    # harmonic mean of precision and recall
    try:
        return 2 * precision * recall / (precision + recall)
    except ZeroDivisionError:
        return 0.0
[5]
precision = 0.5
recall = 0.5
f1_score(precision,recall)
0.5
[6]
precision = 0.1
recall = 0.9
f1_score(precision,recall)
0.18000000000000002
[7]
precision = 0.0
recall = 0.9
f1_score(precision,recall)
0.0
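As a harmonic mean, F1 stays close to the smaller of its two inputs: precision 0.1 with recall 0.9 yields 0.18 rather than the arithmetic mean of 0.5, and a zero in either metric drives F1 to zero.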
[8]
from sklearn import datasets

digits = datasets.load_digits()
X = digits.data
y = digits.target.copy()

y[digits.target==9] = 1
y[digits.target!=9] = 0
[9]
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y,random_state=666)
[10]
from sklearn.linear_model import LogisticRegression

log_reg = LogisticRegression()
log_reg.fit(X_train,y_train)
log_reg.score(X_test,y_test)
F:\anaconda\lib\site-packages\sklearn\linear_model\_logistic.py:814: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(

0.9755555555555555
[12]
y_predict = log_reg.predict(X_test)
[14]
from sklearn.metrics import confusion_matrix

confusion_matrix(y_test, y_predict)
array([[403,   2],
       [  9,  36]], dtype=int64)
[15]
from sklearn.metrics import precision_score

precision_score(y_test,y_predict)
0.9473684210526315
[16]
from sklearn.metrics import recall_score

recall_score(y_test,y_predict)
0.8
[17]
from sklearn.metrics import f1_score

f1_score(y_test, y_predict)
0.8674698795180723
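This agrees with combining the two metrics by hand: 2 × 0.9474 × 0.8 / (0.9474 + 0.8) ≈ 0.8675.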
[19]
log_reg.decision_function(X_test)
array([-21.39857276, -32.89731271, -16.41797156, -79.82318954,
       -48.03305046, -24.18254017, -44.60990955, -24.24479014,
        -1.14284305, -19.00457455, -65.82296325, -50.97066379,
       -30.92082895, -45.94864685, -37.36152924, -29.51329291,
       -36.92856241, -82.80968102, -37.63648469,  -9.8788178 ,
        -9.26807376, -85.25151511, -16.75031683, -45.3443087 ,
        -5.02564992, -48.29794851, -11.65881308, -37.36076018,
       -25.08299918, -13.59764839, -16.5953793 , -28.78598514,
       -34.3678504 , -28.52337297,  -8.11452445,  -4.6022814 ,
       -21.94247061, -21.87781719, -31.17562964, -23.36466695,
       -26.90556959, -62.23610493, -37.68704357, -66.36559349,
       -20.10364224, -16.68553543, -18.16727295, -21.5492968 ,
       -28.96549149, -19.61417448,   2.41242539,   7.7293895 ,
       -34.87176036, -42.70947089, -25.63234617, -34.75112951,
        -7.59781243, -49.51333048, -51.52646722,  19.66201134,
       -10.09725489, -32.0060884 , -11.49932898,  -1.42857622,
       -48.69518674, -43.87320098, -24.83993002, -19.60221328,
       -36.64215638,  -3.52332398,  -4.44425929, -19.2097096 ,
       -20.35743524, -40.89507478, -11.8601531 , -32.7541669 ,
       -35.7587069 , -28.5992766 , -55.41729445, -18.82659602,
         4.56820284, -16.46610285, -76.77533257, -58.24489386,
       -30.24372047, -29.42228053, -33.41709641,  -8.41820483,
       -47.91658806, -65.49746283, -16.90883929, -22.17253788,
       -11.28533349, -18.66745327, -69.22403985, -46.39517132,
       -39.45322992, -35.92419637, -17.72138133, -62.96856734,
       -16.85788403, -55.14488072, -28.77104338, -68.47963152,
       -68.85398745,  -6.50137137, -25.51784658, -38.31116618,
       -27.46927833, -15.54375029, -27.47815541, -20.3332547 ,
        12.07445747, -23.0874899 , -35.96861875, -29.87593015,
       -68.95687582, -27.32891417, -54.23494371, -24.63214107,
       -11.85499344, -47.3668394 ,  -2.75048074, -59.68909997,
       -30.98860082,  -8.98734123, -70.83680244, -56.97836911,
       -20.07706325, -21.49966977, -68.28663666, -18.91058226,
       -38.59829624, -57.36383144,  -0.91081426, -22.51004028,
       -22.66179993, -28.99910954, -32.78451092, -20.43310932,
       -11.3535947 ,   4.63057398,   6.26725227,   1.48867388,
        -7.63736213, -39.24004802,  12.15620508, -74.5437931 ,
       -75.08648846, -49.97467006, -11.63081865, -47.61958938,
       -75.41232907, -29.89880625, -63.93514052,  -7.26078617,
        -6.64271099, -18.2199428 , -32.47674504, -17.93503126,
       -43.33439089, -32.70727873, -34.29947784, -72.74689478,
       -15.19084634,  11.48054014, -56.40994066,  -6.03930048,
       -48.38612896, -16.44647469,  -2.13693844, -11.85713489,
       -33.26559831, -51.34042787, -10.38651041, -17.18846078,
        -5.23982411, -25.19373985, -15.70686294,   3.5534034 ,
       -45.03772698, -12.58192379, -25.37999195, -16.56801256,
       -22.17722688, -82.50131039,  -5.8811552 , -20.25621041,
       -20.46383207, -26.80997392, -25.98518361, -40.44912794,
       -38.01122059, -26.9627282 , -23.75636279, -20.15726322,
        -9.69213637, -19.6799691 , -42.49289639, -44.13469938,
       -15.65386714, -64.03047268, -24.55648146, -56.30568399,
       -13.01339393, -29.66652546,   3.89794499, -44.33546306,
        -7.92245618,   1.14543666,  -2.81814751, -11.92929586,
         7.5086596 ,  -7.17718348, -46.39847023, -48.65871982,
        -4.59959364, -19.05437356, -24.07254218, -48.76355552,
       -15.01620526, -24.92137044, -16.69772054, -18.68326579,
       -15.70208152, -16.86386928, -38.52705695, -31.09380281,
        -9.37781861, -71.4453079 , -22.76526306, -14.43837784,
       -23.08137726, -34.31916589,  -0.89221103, -32.73888374,
       -11.21723013, -18.6738182 ,  -8.21484026, -45.43305526,
       -22.30560288, -62.38971913, -46.77028519, -65.15237525,
       -33.22628484, -23.47536421, -28.51024714, -64.78914741,
         1.45290051,  -4.09358964, -25.64587602, -22.32038298,
       -54.68656406, -16.3407006 , -12.06726537, -35.28199188,
        -5.7391347 , -13.52396326, -72.2770459 ,  -6.16552202,
        -1.16494995, -35.58095254, -24.15372831, -68.3152937 ,
        14.76606277, -63.0626057 ,   9.9115143 , -24.1477828 ,
       -32.45732897, -14.38796233, -85.7282472 , -12.77864747,
         8.99482139, -16.51791403, -36.67219629, -16.51511131,
       -19.35718611, -32.583308  ,  -5.64342385,   7.68471894,
         9.38946768,   5.85378475, -35.64899776, -12.98316031,
       -54.42344306, -41.10888515,   5.63263711, -79.47912897,
       -15.82650933, -19.23205602, -10.86309466, -42.52164565,
       -19.81792269, -15.70492451, -17.99800508, -18.02255039,
        -6.75867766, -20.78794591, -16.58125173, -70.42110518,
        -9.21349451, -31.70399615, -19.67558207, -21.95918435,
       -24.77110999, -16.38822309, -13.36794196, -22.93287663,
        11.06093377, -15.37076191, -32.94045314, -13.74640562,
       -50.35815794, -20.45538215, -56.2709184 , -28.68677373,
       -21.86524573, -30.41664698, -69.26034763, -59.34711621,
        14.34093357,   8.57797635, -25.66219805,   2.74054632,
         4.9329685 , -19.66539612, -58.82488345, -10.00833742,
       -28.80946298, -27.20346821,   6.28874155, -80.46777388,
       -34.45717484, -50.28471677, -35.95066935, -48.6313621 ,
       -18.01210551, -62.3428243 ,  -3.09974615, -25.2635612 ,
       -64.10526345,  -9.61660605, -21.76591374,  19.89900139,
       -18.75262552,  -4.46636384, -13.15019258, -21.64298339,
       -43.10021867, -52.10329918, -28.53126446, -14.54900274,
        -2.47647559,  -6.12117544,   3.69187156, -15.0063578 ,
       -40.85876851, -26.64359518,  14.10780389, -17.68798006,
        15.18161223, -33.09641501,   5.26048113, -14.27034463,
       -53.58418085, -50.04146827, -30.668069  , -38.05244113,
       -23.29209606, -24.6960092 , -13.57354354, -22.62553141,
       -27.2290141 , -19.64733979, -28.1768732 , -19.93558149,
       -29.85262347, -11.28766344, -17.24377394, -24.0310721 ,
       -24.35542295,  10.39150921, -17.21009704, -38.02155334,
       -16.08422171, -37.57447399, -16.327524  , -69.12211344,
       -33.67776297, -43.62662563, -26.61467625, -10.32511698,
       -66.36070209, -31.9032331 , -45.56406403, -14.57833594,
       -36.13656958, -14.94377141, -70.01819354, -11.35647733,
       -40.86227952, -32.65545084, -19.77146533, -27.58157471,
       -15.73466776, -31.57608305,  -8.50558639, -21.38402622,
       -34.07101343, -11.68747617, -36.42460337, -34.78640679,
       -22.21781815,   4.77423291, -21.31044306,  -4.45343862,
       -20.8192745 , -32.26057776, -41.11472384, -25.0841837 ,
       -19.76245188, -47.86598828, -30.89389022, -45.55549885,
       -71.52150073,  -6.25498279, -32.5635314 ,   2.27397922,
        11.93710255,   7.1181192 , -31.36293349, -63.9582492 ,
       -23.78891268,  -5.73651065, -32.42584299, -24.7138706 ,
       -67.69974056, -32.8331123 , -33.60887574, -31.53192719,
       -51.97754435, -22.54575078,  -7.74388421, -17.30052337,
       -25.78866235, -32.37585686, -29.48393512, -66.43195243,
       -45.70161834, -16.05036959])
[20]
log_reg.decision_function(X_test)[:10]
array([-21.39857276, -32.89731271, -16.41797156, -79.82318954,
       -48.03305046, -24.18254017, -44.60990955, -24.24479014,
        -1.14284305, -19.00457455])
[21]
log_reg.predict(X_test)[:10]
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
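By default, predict applies a threshold of 0 to these decision scores: the first ten scores are all negative, so all ten predictions are 0. Moving this threshold is what trades precision against recall.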
[22]
decision_scores = log_reg.decision_function(X_test)
[23]
np.min(decision_scores)
-85.72824719706493
[24]
np.max(decision_scores)
19.899001392471213
[25]
y_predict_2 = np.array(decision_scores >= 5, dtype='int')
[26]
confusion_matrix(y_test, y_predict_2)
array([[404,   1],
       [ 21,  24]], dtype=int64)
[27]
precision_score(y_test,y_predict_2)
0.96
[28]
recall_score(y_test,y_predict_2)
0.5333333333333333
[29]
y_predict_3 = np.array(decision_scores >= -5, dtype='int')
[30]
confusion_matrix(y_test, y_predict_3)
array([[390,  15],
       [  5,  40]], dtype=int64)
[31]
precision_score(y_test,y_predict_3)
0.7272727272727273
[32]
recall_score(y_test,y_predict_3)
0.8888888888888888
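Moving the threshold from the default 0 up to 5 raised precision (0.947 → 0.96) while cutting recall (0.8 → 0.53); lowering it to -5 did the reverse (precision 0.73, recall 0.89). A tiny helper makes this knob explicit; predict_with_threshold below is a convenience sketch for this notebook, not a scikit-learn API:

def predict_with_threshold(model, X, threshold=0.0):
    # positive whenever the decision score clears the threshold
    return np.array(model.decision_function(X) >= threshold, dtype='int')

# equivalent to y_predict_2 above:
# predict_with_threshold(log_reg, X_test, threshold=5)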

10-6 The Precision-Recall Curve in scikit-learn

Notebook example

Notebook source code

[8]
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
[2]
digits = datasets.load_digits()
X = digits.data
y = digits.target.copy()

y[digits.target==9] = 1
y[digits.target!=9] = 0
[3]
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y,random_state=666)
[5]
from sklearn.linear_model import LogisticRegression

log_reg = LogisticRegression()
log_reg.fit(X_train,y_train)
decision_score = log_reg.decision_function(X_test)
F:\anaconda\lib\site-packages\sklearn\linear_model\_logistic.py:814: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(

[6]
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score

precisions = []
recalls = []
thresholds = np.arange(np.min(decision_score), np.max(decision_score), 0.1)
for threshold in thresholds:
    y_predict = np.array(decision_score >= threshold, dtype='int')
    precisions.append(precision_score(y_test, y_predict))
    recalls.append(recall_score(y_test, y_predict))
[9]
plt.plot(thresholds, precisions)
plt.plot(thresholds, recalls)
[<matplotlib.lines.Line2D at 0x1bf900a73a0>]
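Where the rising precision curve crosses the falling recall curve is the threshold at which the two metrics balance; pushing the threshold past that point buys precision at the cost of recall, and vice versa.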

The Precision-Recall Curve
[10]
plt.plot(precisions, recalls)

[<matplotlib.lines.Line2D at 0x1bf9016b4c0>]

The Precision-Recall Curve in scikit-learn
[11]
from sklearn.metrics import precision_recall_curve

precisions, recalls, thresholds = precision_recall_curve(y_test, decision_score)
[12]
precisions.shape
(151,)
[13]
recalls.shape
(151,)
[14]
thresholds.shape
(150,)
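precision_recall_curve returns one more precision/recall value than thresholds: scikit-learn appends a final point (precision 1, recall 0) that has no corresponding threshold, which is why the next cell drops it with [:-1] before plotting against thresholds. Unlike our manual 0.1-step scan, the function also picks the thresholds itself from the distinct decision scores.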
[15]
plt.plot(thresholds, precisions[:-1])
plt.plot(thresholds, recalls[:-1])
[<matplotlib.lines.Line2D at 0x1bf901f6fa0>]

[16]
plt.plot(precisions, recalls)
[<matplotlib.lines.Line2D at 0x1bf901e8160>]

10-7 ROC

Notebook example

Notebook source code

The ROC Curve
[1]
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
[2]
digits = datasets.load_digits()
X = digits.data
y = digits.target.copy()

y[digits.target==9] = 1
y[digits.target!=9] = 0
[3]
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y,random_state=666)
[4]
from sklearn.linear_model import LogisticRegression

log_reg = LogisticRegression()
log_reg.fit(X_train,y_train)
decision_score = log_reg.decision_function(X_test)
F:\anaconda\lib\site-packages\sklearn\linear_model\_logistic.py:814: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(

[5]
from playML.metrics import FPR, TPR

fprs = []
tprs = []
thresholds = np.arange(np.min(decision_score), np.max(decision_score),0.1)
for threshold in thresholds:
    y_predict = np.array(decision_score >= threshold,dtype='int')
    fprs.append(FPR(y_test, y_predict))
    tprs.append(TPR(y_test,y_predict))
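playML here is the course's own helper package, whose source is not shown. A plausible sketch of its FPR and TPR, assuming the same TN/FP/FN/TP helpers defined earlier in this chapter (the zero-denominator guards are an assumption):

def FPR(y_true, y_predict):
    # false positive rate: FP / (FP + TN), the x-axis of the ROC curve
    fp = FP(y_true, y_predict)
    tn = TN(y_true, y_predict)
    return fp / (fp + tn) if fp + tn != 0 else 0.0

def TPR(y_true, y_predict):
    # true positive rate (identical to recall): TP / (TP + FN), the y-axis
    tp = TP(y_true, y_predict)
    fn = FN(y_true, y_predict)
    return tp / (tp + fn) if tp + fn != 0 else 0.0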
[6]
plt.plot(fprs,tprs)
[<matplotlib.lines.Line2D at 0x1600c413820>]

ROC in scikit-learn
[7]
from sklearn.metrics import roc_curve

fps, tps, thresholds = roc_curve(y_test,decision_score)
[8]
plt.plot(fps, tps)
[<matplotlib.lines.Line2D at 0x1600c190cd0>]

[9]
from sklearn.metrics import roc_auc_score

roc_auc_score(y_test,decision_score)
0.9823868312757201
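A ROC AUC of about 0.98 means that a randomly chosen positive sample receives a higher decision score than a randomly chosen negative one about 98% of the time; it summarizes the whole curve in a single threshold-independent number.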

10-8 The Confusion Matrix in Multiclass Problems

Notebook example

Notebook source code

The Confusion Matrix in Multiclass Problems
[2]
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
[5]
digits = datasets.load_digits()
X = digits.data
y = digits.target.copy()
[6]
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y,random_state=666)
[8]
from sklearn.linear_model import LogisticRegression

log_reg = LogisticRegression()
log_reg.fit(X_train,y_train)
log_reg.score(X_test,y_test)
F:\anaconda\lib\site-packages\sklearn\linear_model\_logistic.py:814: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(

0.9711111111111111
[9]
y_predict = log_reg.predict(X_test)
[10]
from sklearn.metrics import precision_score

precision_score(y_test,y_predict,average='micro')
0.9711111111111111
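With average='micro', true positives and false positives are counted over all ten classes together; in a single-label problem every false positive for one class is a false negative for another, so micro-averaged precision collapses to overall accuracy, exactly the 0.9711 seen above. The average argument is needed because precision_score defaults to average='binary', which only works for two classes.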
[11]
from sklearn.metrics import confusion_matrix

confusion_matrix(y_test,y_predict)
array([[46,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0, 40,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0, 50,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  1, 50,  0,  0,  0,  0,  1,  1],
       [ 0,  0,  0,  0, 47,  0,  0,  0,  1,  0],
       [ 0,  0,  0,  0,  0, 37,  0,  1,  0,  0],
       [ 0,  0,  0,  0,  0,  1, 38,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0, 43,  0,  0],
       [ 0,  0,  0,  0,  1,  2,  0,  1, 44,  0],
       [ 0,  0,  0,  1,  0,  2,  0,  0,  0, 42]], dtype=int64)
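The layout matches the binary case: row i is the true digit and column j the predicted digit, so the diagonal counts correct predictions and each off-diagonal cell is one specific confusion (for example, the 1 in row 6, column 5 is a true 6 misread as a 5).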
[12]
cfm = confusion_matrix(y_test,y_predict)
plt.matshow(cfm, cmap=plt.cm.gray)
<matplotlib.image.AxesImage at 0x1cdcf0153d0>

[14]
row_sums = np.sum(cfm, axis=1)
# reshape to a column vector so each row is divided by its own total;
# a bare row_sums would broadcast across columns and normalize by the
# wrong class counts
err_matrix = cfm / row_sums.reshape(-1, 1)
np.fill_diagonal(err_matrix, 0)
err_matrix
array([[0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.01886792, 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.01886792, 0.01886792],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.02083333, 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.02631579, 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.02564103, 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.02083333,
        0.04166667, 0.        , 0.02083333, 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.02222222, 0.        ,
        0.04444444, 0.        , 0.        , 0.        , 0.        ]])
[15]
plt.matshow(err_matrix, cmap=plt.cm.gray)
<matplotlib.image.AxesImage at 0x1cdcf386550>
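With the diagonal zeroed out, the brightest cells of this error map point straight at the classifier's systematic confusions: here, true 8s and 9s are most often misread as 5s, and 3s scatter into 2, 8, and 9. That is far more actionable than a single accuracy number.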
