首先感谢博主倔强的小彬雅,本文使用的素材及部分代码来源其博文机器学习入门-用KNN实现手写数字图片识别(包含自己图片转化),需要下载素材的可以到其博文最后进行下载。
关于KNN算法
knn算法也叫K临近算法
简单举个例子,如上图所示,坐标轴内随机分布这红色和绿色两种属性的图形,现在新加入了一个点,怎么来判断这个点可能是红色还是绿色呢?
我们取一个值K1=1,发现在离新加入这个点最近的K1个点是红色的,红色的点多于绿色的点,那么新加入的点很可能是红色的。
同样,取一个值K2=5,发现在离新加入这个点最近的K2个点中有2个是红色的,3个是绿色的,红色的点少于绿色的点,那么新加入的点很可能是绿色的。
以上就是KNN算法的原理,也可以看出,k值的取值非常重要,实验中可以调整K值的取值来改善识别的准确度。
KNN实现手写数字识别原理
实验前需要准备训练数据集和实验数据集。
要实现手写数字图片的识别,关在在于图片的向量化。而在图片向量化之前,先要处理图片,因为一张图片可能色彩比较复杂,大小也不一致,要对其进行降噪,避免干扰。另外在处理过程中将图片大小统一压缩成了32*32像素的灰度图像。
图片处理后, 图片为32*32像素,遍历图片中的每一个像素,计算其灰度值,这里采用灰度计算公式Gray = R*0.299 + G*0.587 + B*0.114,如果这个像素点的灰度值计算出来为255(即白色)则标记为1,其余点统一标记为0,将得到的01字符串保存在txt文件中,作为计算数据。
因为图片为32*32像素,故有1024个像素点,txt文件中同样有32行32列共1024个数值,我们把这1024个数值转换成一个1行1024列的矩阵向量,并计算其与各个训练数据的距离。(训练数据同样为0、1组成的txt文件,也转成向量)。距离采用欧氏距离公示进行计算。
最后来看,与测试数据最近的K个训练数据中,哪个数字的数量是最多的,那么就判断这个手写数字是几。
代码实现
图片处理部分
from PIL import Image
import matplotlib.pylab as plt
import numpy as np
import os.path
def pic_change(filename):
img=Image.open('E:/test/pic/'+filename)#我的文件路径为E:/test/pic/下,可修改
raw_data=img.load()#加载图片
#降噪部分
for x in range(img.size[0]):#x*y即像素值,x为行,y为列,遍历每个像素进行降噪
for y in range(img.size[1]):
if raw_data[x, y][0] < 90:#遍历像素,png图片每个像素有RGBA四个值,A值指透明度,透明度统一设为255
raw_data[x, y] = (0, 0, 0, 255)
for x in range(img.size[1]):
for y in range(img.size[0]):
if raw_data[x, y][1] < 136:
raw_data[x, y] = (0, 0, 0, 255)
for x in range(img.size[1]):
for y in range(img.size[0]):
if raw_data[x, y][2] > 0:
raw_data[x, y] = (255, 255, 255, 255)
resize_img=img.resize((32,32),Image.LANCZOS)#修改图片大小为32*32像素
resize_img.save('E:/test/new_'+filename)#存储新图片,可删
array=plt.array(resize_img)
gray_array = np.zeros((32, 32))
for x in range(array.shape[0]):
for y in range(array.shape[1]):
Gray = array[x,y][0] * 0.299 + array[x,y][1] * 0.587 + array[x,y][2] * 0.114#灰度公式计算每个像素点,将图片的每个像素转为0或1
if Gray==255:
gray_array[x,y]=0
else:
gray_array[x,y]=1
testFileList = os.listdir('E:/test/testDigits')
mTest = len(testFileList)
num=0
for i in range(mTest):#将转换好的图片存储在E:/test/testDigits文件夹下,这一部分为防止重名,对文件名中序号进行判断
testName=testFileList[i].split('_')[0]
if testName==filename.split('.')[0]:
num=num+1
new_txt_name = filename.split('.')[0] + '_' + str(num+1) + '.txt'#文件名格式为8_1.txt,8值这个数字的真实值,1为序号
np.savetxt('E:/test/testDigits/' + new_txt_name, gray_array, fmt='%d', delimiter='')
print(new_txt_name)
return new_txt_name
pic_list=os.listdir('E:/test/pic')#未处理的图片放在E:/test/pic文件夹下,处理完成后放在E:/test/testDigits文件夹下
n=len(pic_list)#遍历,对每张图片进行处理
for k in range(n):
print(pic_list[k])
pic_change(pic_list[k])
KNN算法执行部分
import numpy as np
import os.path
def img32_to_1024(filename):
returnVec=np.zeros((1,1024))#生成1行,1024列的矩阵向量,便于计算
file=open('E:/test/'+filename,'r')
linestr = file.readlines()
for i in range(32):
for n in range(32):
returnVec[0,i*32+n]=int(linestr[i][n])
return(returnVec)
def test(index,k):#KNN计算部分函数,传入测试数据向量和k值
dataSetSize = m
diffMat=index-trainingMat#以下四行即欧式公式计算距离
diffMat1=diffMat**2
diffMat2=diffMat1.sum(axis=1)
distances=diffMat2**0.5
sortedDistances = distances.argsort()#计算结果排序
classCount = {}
for i in range(k):#计算K个数据中哪个数字是最多的
voteIlabel =handWriteLabels[sortedDistances[i]]
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
sortedClassCount = sorted(classCount.items(), reverse=True)
return sortedClassCount[0][0]#返回计算出的手写数字值
#训练数据集处理
handWriteLabels=[]
trainingFileList=os.listdir('E:/test/trainingDigits')
m=len(trainingFileList)
trainingMat=np.zeros((m,1024))
for i in range(m):
fileNameStr = trainingFileList[i]
fileStr=fileNameStr.split('.')[0]
classNumStr=int(fileStr.split('_')[0])
handWriteLabels.append(classNumStr)
trainingMat[i, :]=img32_to_1024('trainingDigits/' + fileNameStr)
#测试数据集处理
testFileList=os.listdir('E:/test/testDigits')
errorCount=0.0#用于计算错误率
mTest=len(testFileList)
realLabel=[]
testLabel=[]
for i in range(mTest):
fileNameStr=testFileList[i]
fileStr=fileNameStr.split('.')[0]
classNumStr=int(fileStr.split('_')[0])
vectorUnderTest=img32_to_1024('testDigits/' + fileNameStr)
testLabel=test(vectorUnderTest,1)#1为k值,可修改为2、3、4……
print('识别出的数字为:'+str(testLabel)+'真实数字为:'+ str(classNumStr))
if(testLabel!=classNumStr):
errorCount+=1.0
print('错误率'+str(errorCount/mTest))
运行结果
k值为2时
K值取2时错误率最低