10-1 The Pitfall of Accuracy and the Confusion Matrix
Accuracy by itself can be badly misleading on skewed data: if 99.9% of samples are negative, a model that always predicts "negative" scores 99.9% accuracy while detecting nothing. The confusion matrix exposes what accuracy hides by splitting a binary classifier's results into four counts: TN (true negatives), FP (false positives), FN (false negatives) and TP (true positives), with rows indexed by the true class and columns by the predicted class.
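A quick way to see the trap (a minimal sketch using the same "is it a 9?" digits task the notebook below sets up, where only about 10% of the labels are positive):

import numpy as np
from sklearn import datasets

digits = datasets.load_digits()
y = (digits.target == 9).astype(int)  # 1 = digit 9, 0 = everything else
y_always_zero = np.zeros_like(y)      # baseline that always predicts the majority class
np.mean(y_always_zero == y)           # ~0.90 accuracy without learning anything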
10-2 Precision and Recall
Precision and recall each focus on one kind of error. Precision = TP / (TP + FP): of the samples predicted positive, the fraction that really are positive. Recall = TP / (TP + FN): of the truly positive samples, the fraction the classifier finds. Both are anchored on the (usually rare) positive class, so neither is inflated by skewed data the way accuracy is.
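A worked example with the counts the notebook below actually produces (TN = 403, FP = 2, FN = 9, TP = 36): precision = 36 / (36 + 2) ≈ 0.947, and recall = 36 / (36 + 9) = 0.8.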
10-3 Implementing the Confusion Matrix, Precision, and Recall
Notebook example

Notebook source code
[1]
import numpy as np
from sklearn import datasets
[2]
digits = datasets.load_digits()
X = digits.data
y = digits.target.copy()
y[digits.target == 9] = 1
y[digits.target != 9] = 0
[3]
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
[4]
X_train.shape
(1347, 64)
[5]
from sklearn.linear_model import LogisticRegression
log_reg = LogisticRegression()
log_reg.fit(X_train, y_train)
log_reg.score(X_test, y_test)
F:\anaconda\lib\site-packages\sklearn\linear_model\_logistic.py:814: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.
Increase the number of iterations (max_iter) or scale the data as shown in:
https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
n_iter_i = _check_optimize_result(
0.9755555555555555
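The ConvergenceWarning (repeated after every fit below) is harmless for our purposes, but it can be removed as the message suggests, by scaling the inputs or raising the iteration limit. A minimal sketch (max_iter=3000 is an arbitrary choice):

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

log_reg = make_pipeline(StandardScaler(), LogisticRegression())  # scale, then fit
# or simply: log_reg = LogisticRegression(max_iter=3000)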
[6]
y_log_predict = log_reg.predict(X_test)
[7]
def TN(y_true, y_predict):
    assert len(y_true) == len(y_predict)
    # true negatives: actually 0, predicted 0
    return np.sum((y_true == 0) & (y_predict == 0))

TN(y_test, y_log_predict)
403
[8]
def FP(y_true, y_predict):
    assert len(y_true) == len(y_predict)
    # false positives: actually 0, predicted 1
    return np.sum((y_true == 0) & (y_predict == 1))

FP(y_test, y_log_predict)
2
[9]
def FN(y_true, y_predict):
    assert len(y_true) == len(y_predict)
    # false negatives: actually 1, predicted 0
    return np.sum((y_true == 1) & (y_predict == 0))

FN(y_test, y_log_predict)
9
[10]
def TP(y_true, y_predict):
    assert len(y_true) == len(y_predict)
    # true positives: actually 1, predicted 1
    return np.sum((y_true == 1) & (y_predict == 1))

TP(y_test, y_log_predict)
36
[11]
def confusion_matrix(y_true, y_predict):
    # rows are the true class (0, 1); columns are the predicted class (0, 1)
    return np.array([
        [TN(y_true, y_predict), FP(y_true, y_predict)],
        [FN(y_true, y_predict), TP(y_true, y_predict)]
    ])

confusion_matrix(y_test, y_log_predict)
array([[403, 2],
[ 9, 36]])
[12]
def precision_score(y_true, y_predict):
    tp = TP(y_true, y_predict)
    fp = FP(y_true, y_predict)
    try:
        return tp / (tp + fp)
    except ZeroDivisionError:
        return 0.0

precision_score(y_test, y_log_predict)
0.9473684210526315
[13]
def recall_score(y_true, y_predict):
    tp = TP(y_true, y_predict)
    fn = FN(y_true, y_predict)
    try:
        return tp / (tp + fn)
    except ZeroDivisionError:
        return 0.0

recall_score(y_test, y_log_predict)
0.8
Confusion matrix, precision, and recall in scikit-learn
[14]
from sklearn.metrics import confusion_matrix
confusion_matrix(y_test, y_log_predict)
array([[403, 2],
[ 9, 36]], dtype=int64)
[15]
from sklearn.metrics import precision_score
precision_score(y_test, y_log_predict)
0.9473684210526315
[16]
from sklearn.metrics import recall_score
recall_score(y_test, y_log_predict)
0.8
10-4 F1 Score
The F1 score combines precision and recall into a single number, their harmonic mean: F1 = 2 · precision · recall / (precision + recall). Unlike the arithmetic mean, the harmonic mean is dominated by the smaller of the two values, so F1 is high only when precision and recall are both high.
10-5 The Precision-Recall Trade-off
Precision and recall pull against each other, and the knob that trades one for the other is the classification threshold. Logistic regression predicts class 1 exactly when its decision score is at least 0; raising that threshold makes the classifier more conservative (precision up, recall down), and lowering it does the opposite. The notebook below moves the threshold by hand using decision_function.
Notebook example

Notebook source code
F1 Score
[18]
import numpy as np
import matplotlib.pyplot as plt
[4]
def f1_score(precision, recall):
    try:
        return 2 * precision * recall / (precision + recall)
    except ZeroDivisionError:
        return 0.0
[5]
precision = 0.5
recall = 0.5
f1_score(precision, recall)
0.5
[6]
precision = 0.1
recall = 0.9
f1_score(precision, recall)
0.18000000000000002
[7]
precision = 0.0
recall = 0.9
f1_score(precision, recall)
0.0
[8]
from sklearn import datasets
digits = datasets.load_digits()
X = digits.data
y = digits.target.copy()
y[digits.target == 9] = 1
y[digits.target != 9] = 0
[9]
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
[10]
from sklearn.linear_model import LogisticRegression
log_reg = LogisticRegression()
log_reg.fit(X_train, y_train)
log_reg.score(X_test, y_test)
ConvergenceWarning: lbfgs failed to converge (status=1); same warning as in 10-3.
0.9755555555555555
[12]
y_predict = log_reg.predict(X_test)
[14]
from sklearn.metrics import confusion_matrix
confusion_matrix(y_test, y_predict)
array([[403, 2],
[ 9, 36]], dtype=int64)
[15]
from sklearn.metrics import precision_score
precision_score(y_test, y_predict)
0.9473684210526315
[16]
from sklearn.metrics import recall_score
recall_score(y_test, y_predict)
0.8
[17]
from sklearn.metrics import f1_score
f1_score(y_test, y_predict)
0.8674698795180723
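As a sanity check against the hand-rolled formula: 2 × 0.9474 × 0.8 / (0.9474 + 0.8) ≈ 0.8675, matching sklearn's f1_score.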
[19]
log_reg.decision_function(X_test)
array([-21.39857276, -32.89731271, -16.41797156, -79.82318954,
       -48.03305046, -24.18254017, -44.60990955, -24.24479014,
       -1.14284305, -19.00457455,
       ...,
       -45.70161834, -16.05036959])
(output truncated: one decision score per test sample, 450 values in total)
[20]
log_reg.decision_function(X_test)[:10]
array([-21.39857276, -32.89731271, -16.41797156, -79.82318954,
-48.03305046, -24.18254017, -44.60990955, -24.24479014,
-1.14284305, -19.00457455])
[21]
log_reg.predict(X_test)[:10]
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
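The first ten predictions are all 0 precisely because the first ten decision scores above are all negative: for binary logistic regression, predict returns class 1 exactly when decision_function is positive, i.e. the default threshold is 0.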
[22]
decision_scores = log_reg.decision_function(X_test)
[23]
np.min(decision_scores)
-85.72824719706493
[24]
np.max(decision_scores)
19.899001392471213
[25]
y_predict_2 = np.array(decision_scores >= 5, dtype='int')
[26]
confusion_matrix(y_test, y_predict_2)
array([[404, 1],
[ 21, 24]], dtype=int64)
[27]
precision_score(y_test, y_predict_2)
0.96
[28]
recall_score(y_test, y_predict_2)
0.5333333333333333
[29]
y_predict_3 = np.array(decision_scores >= -5, dtype='int')
[30]
confusion_matrix(y_test, y_predict_3)
array([[390, 15],
[ 5, 40]], dtype=int64)
[31]
precision_score(y_test, y_predict_3)
0.7272727272727273
[32]
recall_score(y_test, y_predict_3)
0.8888888888888888
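Raising the threshold from 0 to 5 lifted precision from 0.947 to 0.96 but cut recall from 0.8 to 0.53; lowering it to -5 did the reverse (precision 0.73, recall 0.89). That is the precision-recall trade-off: every threshold is a different operating point, and which one is right depends on the application.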
10-6 The Precision-Recall Curve in scikit-learn
Sweeping the threshold and recording precision and recall at each value gives two views of the same information: each metric plotted against the threshold, and precision plotted directly against recall (the precision-recall curve).
Notebook example

Notebook source code
[8]
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
[2]
digits = datasets.load_digits()
X = digits.data
y = digits.target.copy()
y[digits.target == 9] = 1
y[digits.target != 9] = 0
[3]
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
[5]
from sklearn.linear_model import LogisticRegression
log_reg = LogisticRegression()
log_reg.fit(X_train, y_train)
decision_score = log_reg.decision_function(X_test)
ConvergenceWarning: lbfgs failed to converge (status=1); same warning as in 10-3.
[6]
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score

precisions = []
recalls = []
thresholds = np.arange(np.min(decision_score), np.max(decision_score), 0.1)
for threshold in thresholds:
    # classify with a moving threshold on the decision score
    y_predict = np.array(decision_score >= threshold, dtype='int')
    precisions.append(precision_score(y_test, y_predict))
    recalls.append(recall_score(y_test, y_predict))
[9]
plt.plot(thresholds, precisions)
plt.plot(thresholds, recalls)
[<matplotlib.lines.Line2D at 0x1bf900a73a0>]
Precision-Recall curve
[10]
plt.plot(precisions, recalls)
[<matplotlib.lines.Line2D at 0x1bf9016b4c0>]
The Precision-Recall curve in scikit-learn
[11]
from sklearn.metrics import precision_recall_curve
precisions, recalls, thresholds = precision_recall_curve(y_test, decision_score)
[12]
precisions.shape
(151,)
[13]
recalls.shape
(151,)
[14]
thresholds.shape
(150,)
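The shapes differ by one by design: precision_recall_curve appends a final precision of 1 and recall of 0 that have no corresponding threshold, and it computes the thresholds from the data instead of stepping by a fixed 0.1. That is why the next cell drops the last element of each array when plotting against thresholds.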
[15]
plt.plot(thresholds, precisions[:-1])
plt.plot(thresholds, recalls[:-1])
[<matplotlib.lines.Line2D at 0x1bf901f6fa0>]
[16]
plt.plot(precisions, recalls)
[<matplotlib.lines.Line2D at 0x1bf901e8160>]
10-7 ROC
The ROC (Receiver Operating Characteristic) curve plots the true positive rate, TPR = TP / (TP + FN) (identical to recall), against the false positive rate, FPR = FP / (FP + TN), as the classification threshold varies. The area under the curve (AUC) summarizes it in one number: 1.0 is a perfect ranking of positives above negatives, while 0.5 is random guessing.
Notebook example

Notebook source code
ROC curve
[1]
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
[2]
digits = datasets.load_digits()
X = digits.data
y = digits.target.copy()
y[digits.target == 9] = 1
y[digits.target != 9] = 0
[3]
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
[4]
from sklearn.linear_model import LogisticRegression
log_reg = LogisticRegression()
log_reg.fit(X_train, y_train)
decision_score = log_reg.decision_function(X_test)
ConvergenceWarning: lbfgs failed to converge (status=1); same warning as in 10-3.
[5]
from playML.metrics import FPR, TPR

fprs = []
tprs = []
thresholds = np.arange(np.min(decision_score), np.max(decision_score), 0.1)
for threshold in thresholds:
    # classify with a moving threshold and record both rates
    y_predict = np.array(decision_score >= threshold, dtype='int')
    fprs.append(FPR(y_test, y_predict))
    tprs.append(TPR(y_test, y_predict))
[6]
plt.plot(fprs, tprs)
[<matplotlib.lines.Line2D at 0x1600c413820>]
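The FPR and TPR helpers come from playML, the course's own package. If it isn't available, equivalent functions can be defined directly, in the style of the 10-3 helpers; a minimal sketch:

import numpy as np

def FPR(y_true, y_predict):
    # false positive rate = FP / (FP + TN)
    fp = np.sum((y_true == 0) & (y_predict == 1))
    tn = np.sum((y_true == 0) & (y_predict == 0))
    return fp / (fp + tn) if fp + tn > 0 else 0.0

def TPR(y_true, y_predict):
    # true positive rate (= recall) = TP / (TP + FN)
    tp = np.sum((y_true == 1) & (y_predict == 1))
    fn = np.sum((y_true == 1) & (y_predict == 0))
    return tp / (tp + fn) if tp + fn > 0 else 0.0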
ROC in scikit-learn
[7]
from sklearn.metrics import roc_curve
fps, tps, thresholds = roc_curve(y_test, decision_score)
[8]
plt.plot(fps, tps)
[<matplotlib.lines.Line2D at 0x1600c190cd0>]
[9]
from sklearn.metrics import roc_auc_score
roc_auc_score(y_test, decision_score)
0.9823868312757201
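An AUC of 0.982 means that for a randomly drawn positive/negative pair, the model ranks the positive sample higher about 98% of the time. Note that roc_auc_score is computed from the continuous decision scores, not from thresholded 0/1 predictions.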
10-8 The Confusion Matrix for Multi-class Problems
Notebook example

Notebook source code
Confusion matrix for multi-class problems
[2]
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
[5]
digits = datasets.load_digits()
X = digits.data
y = digits.target.copy()
[6]
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
[8]
from sklearn.linear_model import LogisticRegression
log_reg = LogisticRegression()
log_reg.fit(X_train, y_train)
log_reg.score(X_test, y_test)
ConvergenceWarning: lbfgs failed to converge (status=1); same warning as in 10-3.
0.9711111111111111
[9]
y_predict = log_reg.predict(X_test)
[10]
from sklearn.metrics import precision_score
precision_score(y_test, y_predict, average='micro')
0.9711111111111111
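For a single-label multi-class problem, micro-averaging pools the TP and FP counts over all classes, and the resulting precision equals plain accuracy, which is why it reproduces the 0.9711 score above.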
[11]
from sklearn.metrics import confusion_matrix
confusion_matrix(y_test, y_predict)
array([[46, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[ 0, 40, 0, 0, 0, 0, 0, 0, 0, 0],
[ 0, 0, 50, 0, 0, 0, 0, 0, 0, 0],
[ 0, 0, 1, 50, 0, 0, 0, 0, 1, 1],
[ 0, 0, 0, 0, 47, 0, 0, 0, 1, 0],
[ 0, 0, 0, 0, 0, 37, 0, 1, 0, 0],
[ 0, 0, 0, 0, 0, 1, 38, 0, 0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 43, 0, 0],
[ 0, 0, 0, 0, 1, 2, 0, 1, 44, 0],
[ 0, 0, 0, 1, 0, 2, 0, 0, 0, 42]], dtype=int64)
[12]
cfm = confusion_matrix(y_test, y_predict)
plt.matshow(cfm, cmap=plt.cm.gray)
<matplotlib.image.AxesImage at 0x1cdcf0153d0>
[14]
row_sums = np.sum(cfm, axis=1)   # number of samples in each true class
err_matrix = cfm / row_sums      # normalize the counts (broadcasting divides column j by row_sums[j])
np.fill_diagonal(err_matrix, 0)  # zero the diagonal so only the errors remain
err_matrix
array([[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ],
[0. , 0. , 0.02 , 0. , 0. ,
0. , 0. , 0. , 0.02083333, 0.02222222],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.02083333, 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.02325581, 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0.02631579, 0. , 0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0.02083333,
0.05263158, 0. , 0.02325581, 0. , 0. ],
[0. , 0. , 0. , 0.01886792, 0. ,
0.05263158, 0. , 0. , 0. , 0. ]])
[15]
plt.matshow(err_matrix, cmap=plt.cm.gray)
<matplotlib.image.AxesImage at 0x1cdcf386550>
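In this plot the diagonal is dark (the correct predictions were zeroed out) and the brightest cells mark the most frequent mistakes; here the largest entries are true 8s and 9s being predicted as 5, which points to where the classifier, or the data for those digits, deserves a closer look.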