KNN最优K值的选取——学习曲线
方法1:网格(暴力)搜索
# Load the breast-cancer dataset and hold out 30% of it as a test split.
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

data = load_breast_cancer()
X = data.data
y = data.target
# random_state pinned so the split (and all scores below) are reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=420
)
代码的目的是通过循环改变 K-Nearest Neighbors(KNN)算法中的参数 k(即邻居数),并评估不同 k 值下模型在训练集和测试集上的准确率
# Sweep k = 1..20; for each k fit a KNN classifier and record its accuracy
# on both the training and the held-out test split.
train_score_list = []
test_score_list = []
for k in range(1, 21):
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(X_train, y_train)
    train_score_list.append(knn.score(X_train, y_train))
    test_score_list.append(knn.score(X_test, y_test))
print(train_score_list)
# Example output (train accuracy decreases as k grows — less overfitting):
# [1.0, 0.9723618090452262, 0.9623115577889447, ..., 0.9346733668341709]
print(test_score_list)
# Example output (test accuracy peaks in the mid range of k):
# [0.8888888888888888, 0.8771929824561403, ..., 0.9239766081871345]
# Plot both learning curves, then report the best test accuracy and the
# k value that achieves it.
plt.figure(dpi=200)
plt.plot(range(1, 21), train_score_list, label='train_score')
plt.plot(range(1, 21), test_score_list, label='test_score')
plt.xlabel('k_value')
plt.ylabel('score')  # fixed typo: was 'socre'
plt.legend()
plt.show()
# Wrapped in print(): a bare expression is a no-op in a script.
print(max(test_score_list))            # best test accuracy, e.g. 0.935672514619883
print(np.argmax(test_score_list) + 1)  # +1 maps 0-based index to k, e.g. 8
方法2:交叉验证
# Method 2: choose k via 5-fold cross-validation on the training split,
# recording the mean and variance of the CV accuracy for each k.
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split, cross_val_score
import matplotlib.pyplot as plt

data = load_breast_cancer()
X = data.data
y = data.target
# Same seed as method 1 so the two approaches see identical splits.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=420
)

train_score_list = []
test_score_list = []
score = []  # mean 5-fold CV accuracy per k
var = []    # variance of the 5 fold accuracies per k
for k in range(1, 21):
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(X_train, y_train)
    train_score_list.append(knn.score(X_train, y_train))
    test_score_list.append(knn.score(X_test, y_test))
    cross = cross_val_score(knn, X_train, y_train, cv=5)
    score.append(cross.mean())
    var.append(cross.var())
print(score)
# Example output (mean CV accuracy; peaks at ~0.93965 for k = 10):
# [0.9220886075949366, 0.9270569620253164, ..., 0.9245569620253165]
print(var)
# Example output (small variances, on the order of 1e-4):
# [0.00021623738182983474, ..., 0.0005177395449447208]
# Plot train/test curves plus the mean CV score with a dashed band, then
# report the k that maximizes the mean CV accuracy.
plt.figure(dpi=200)
plt.plot(range(1, 21), train_score_list, label='train_score')
plt.plot(range(1, 21), test_score_list, label='test_score')
plt.plot(range(1, 21), score, color='g', label='cross_score')
# NOTE(review): this band is mean +/- 2*variance, as in the original; a
# conventional +/-2-sigma band would use the standard deviation instead.
plt.plot(range(1, 21), np.array(score) + np.array(var) * 2, c='red', linestyle='--')
plt.plot(range(1, 21), np.array(score) - np.array(var) * 2, c='red', linestyle='--')
plt.xlabel('k_value')
plt.ylabel('score')  # fixed typo: was 'socre'
plt.legend()
plt.show()  # added for consistency with method 1 so the figure displays in a script
# Wrapped in print(): bare expressions are no-ops outside a REPL.
print(np.argmax(score) + 1)  # best k by mean CV accuracy, e.g. 10
print(max(score))            # e.g. 0.9396518987341773