KNN对手写数字进行识别超参数网格优化等

该博客介绍了使用KNN算法对手写数字进行识别,通过调整超参数P和权重,优化模型性能。首先展示基本的KNN识别过程,然后探讨超参数P对模型准确性的影响,最后运用网格搜索方法进行参数调优,得出最佳参数组合。
摘要由CSDN通过智能技术生成

1、使用KNN对手写数字进行识别

# !/usr/bin/env python3
import matplotlib
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import numpy as np

data = datasets.load_digits()
x = data.data
y = data.target
'''将位图显示在plt中'''
# sample = x[667].reshape(8,8)
# plt.imshow(sample,cmap='binary')
# plt.show()
'''转成列矩阵'''
h = y.reshape(-1,1)
'''将原本的特征值和标签合并方便后面shuffle后对应关系混乱'''
data = np.concatenate((x,h),axis=1)
'''打乱矩阵'''
np.random.shuffle(data)
'''从第1行到长度的0.8倍行防止存在小数强转整数的问题'''
train_data = data[:int(len(data)*0.8)]
'''从长度的0.8倍行到最末尾防止存在小数强转整数的问题'''
test_data = data[int(len(data)*0.8):]
'''截取矩阵所有行,从第0列到第64列 不包含第64列'''
X_train = train_data[:,0:64]
'''截取矩阵所有行,从第64列到第65列 不包含第65列'''
Y_train = train_data[:,64]
X_test = test_data[:,0:64]
Y_test = test_data[:,64]
'''创建KNN分类器'''
kNN_classifier = KNeighborsClassifier(n_neighbors=6)
'''训练'''
kNN_classifier.fit(X_train,Y_train)
'''预测'''
Y_predict = kNN_classifier.predict(X_test)
'''accuracy_score对模型打分'''
print(accuracy_score(Y_test,Y_predict))

执行结果:
k= 1 accuracy_score= 0.9888888888888889 weights= uniform

2、超参数P推到
在这里插入图片描述
注意只对权重weight = ‘distance’ 有效

# !/usr/bin/env python3
import matplotlib
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import numpy as np

data = datasets.load_digits()
x = data.data
y = data.target
'''将位图显示在plt中'''
# sample = x[667].reshape(8,8)
# plt.imshow(sample,cmap='binary')
# plt.show()
'''转成列矩阵'''
h = y.reshape(-1,1)
'''将原本的特征值和标签合并方便后面shuffle后对应关系混乱'''
data = np.concatenate((x,h),axis=1)
'''打乱矩阵 使用随机种子保障每次实验的数据相等'''
np.random.seed(666)
np.random.shuffle(data)
'''从第1行到长度的0.8倍行防止存在小数强转整数的问题'''
train_data = data[:int(len(data)*0.8)]
'''从长度的0.8倍行到最末尾防止存在小数强转整数的问题'''
test_data = data[int(len(data)*0.8):]
'''截取矩阵所有行,从第0列到第64列 不包含第64列'''
X_train = train_data[:,0:64]
'''截取矩阵所有行,从第64列到第65列 不包含第65列'''
Y_train = train_data[:,64]
X_test = test_data[:,0:64]
Y_test = test_data[:,64]
'''创建KNN分类器   weights= distance 表示将距离的倒数权重纳入考虑范围'''
bast_k = 0
bast_score = 0
bast_p = 0
for pv in range(1,6):
    for i in range(1,11):
            '''
            n_jobs int,默认=无
            为邻居搜索运行的并行作业数。 None除非在
            joblib.parallel_backend上下文中,否则表示 1。 -1意味着使用所有处理器
            P:超参数推到 当为1明科夫斯基距离 当为2明科夫斯基距离的平方就是欧拉距离
            '''
            kNN_classifier = KNeighborsClassifier(n_neighbors=i,weights='distance',n_jobs=-1,p=pv)
            '''训练'''
            kNN_classifier.fit(X_train,Y_train)
            '''预测'''
            Y_predict = kNN_classifier.predict(X_test)
            '''accuracy_score对模型打分'''
            score = accuracy_score(Y_test,Y_predict)
            if score > bast_score:
                bast_score = score
                bast_k = i
                bast_p = pv
print('k=',bast_k,'accuracy_score=',bast_score,'p=',bast_p)


执行结果:k= 2 accuracy_score= 0.9916666666666667 p= 1

3、API提供的网格搜索方法

# !/usr/bin/env python3
import matplotlib
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import numpy as np

data = datasets.load_digits()
x = data.data
y = data.target
'''将位图显示在plt中'''
# sample = x[667].reshape(8,8)
# plt.imshow(sample,cmap='binary')
# plt.show()
'''转成列矩阵'''
h = y.reshape(-1,1)
'''将原本的特征值和标签合并方便后面shuffle后对应关系混乱'''
data = np.concatenate((x,h),axis=1)
'''打乱矩阵 使用随机种子保障每次实验的数据相等'''
np.random.seed(666)
np.random.shuffle(data)
'''从第1行到长度的0.8倍行防止存在小数强转整数的问题'''
train_data = data[:int(len(data)*0.8)]
'''从长度的0.8倍行到最末尾防止存在小数强转整数的问题'''
test_data = data[int(len(data)*0.8):]
'''截取矩阵所有行,从第0列到第64列 不包含第64列'''
X_train = train_data[:,0:64]
'''截取矩阵所有行,从第64列到第65列 不包含第65列'''
Y_train = train_data[:,64]
X_test = test_data[:,0:64]
Y_test = test_data[:,64]

'''网格扫描'''
param_grid = [
    {
     'weights':['distance'],
     'n_neighbors':[i for i in range(1,11)],
     'p':[i for i in range(1,6)]
    },
    {
     'weights':['uniform'],
     'n_neighbors':[i for i in range(1,11)]
    }
]
'''默认对象'''
kNN_classifier = KNeighborsClassifier()
'''创建网格搜索对象'''
from sklearn.model_selection import GridSearchCV
'''verbose 打印出每次迭代的结果 n_jobs = -1表示使用所有cpu cv = 3表示进行3次交叉验证'''
grid_search = GridSearchCV(kNN_classifier,param_grid,n_jobs=-1,verbose=2)
'''训练'''
grid_search.fit(X_train,Y_train)
'''匹配的最好的参数'''
print(grid_search.best_params_)
'''最好的准确度'''
print(grid_search.best_score_)
'''最佳分类器'''
print(grid_search.best_estimator_)
'''获取到最佳分类器'''
estimator = grid_search.best_estimator_
'''预测'''
Y_predict = estimator.predict(X_test)
'''accuracy_score对模型打分'''
score = accuracy_score(Y_test,Y_predict)
print(score)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

sunnyboy_4

你的鼓励是我创作的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值