利用Scikit-learm实现手写数字识别
Scikit-learm库种自带一个手写数字数据集,其中包含1797个手写数值样本,每个样本由8×8的二维数组,数组元素为0-16之间的整数,每个样本都有对应的标签,标签为0-9之间的整数。可以把手写数值识别看成是10分类。
代码:
from sklearn import datasets
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import precision_score
import numpy as np
def draw():#画出手写数据集的前36张图
for i in range(36):
plt.subplot(6,6,i+1)
plt.imshow(digits.images[i])
plt.show()
pass
digits=datasets.load_digits() #调用Scikit库中的手写数据集
x=digits.data #定义数据集中的数据,为[1797,64]的矩阵
y=digits.target #定义数据集中的标签, 为[1797,]的矩阵
x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=0.2)#数据集和测试集比例为8:2
model=MLPClassifier() #调用多层感知机网络,激活函数默认为relu,隐藏层为100个神经元
model.fit(x_train,y_train) #输入训练集训练
y_predict=model.predict(x_test) #输入测试集得到预测结果
error=np.nonzero(y_test - y_predict) #测试集实际结果与预期结果进行比较
p_1=1 - len(error[0]) / len(y_test) #求得总的识别率
p_2=precision_score(y_test,y_predict,average=None) #当average=None时求得是各个类别的识别率
p_3=precision_score(y_test,y_predict,average="micro") #总的识别率,跟p_1一样
print(p_2,p_3)
#print(digits.data.shape)
#print(digits.target.shape)
#print(digits.images.shape)
draw()