__author__ = '糖衣豆豆'
from numpy import *
from os importlistdirimportoperator#从列方向扩展#tile(a,(size,1))#实现KNN算法,需要指定k,需要测试数据集,需要训练数据集,类别名(标签),
defknn(k,testdata,traindata,labels):#通过shape获得行数
traindatasize=traindata.shape[0]#扩展testdata的维数,tile函数可以扩展testdata和traindata相同的行数,然后和traindata的向量相减计算测试机和训练集的差值
dif=tile(testdata,(traindatasize,1))-traindata#计算差值的平方
sqdif=dif**2
#计算平方和,每一行的各列求和,axis=1每一行的各列求和
sumsqdif=sqdif.sum(axis=1)#开方
distance=sumsqdif**0.5
#排序
sortdistance=distance.argsort()#空字典
count={}#选择距离最短的k
for i inrange(0,k):#获取类别,下标决定属于哪一类
vote=labels[sortdistance[i]]#整理为一定格式,得到类别vote,每出现一次统计一次
count[vote]=count.get(vote,0)+1
#取出最多的类别,reverse=True表示降序
sortcount=sorted(count.items(),key=operator.itemgetter(1),reverse=True)returnsortcount[0][0]#图片处理#先将图片转为固定宽高,比如32*32,然后再转为文本
'''from PIL import Image
im=Image.open("~/Downloads/123.png")
fh=open("~/Downloads/123_txt","a")
width=im.size[0]
height=im.size[1]
#k=im.getpixel((1,9))
#print(k)
for i in range(0,width):
for j in range(0,height):
cl=im.getpixel((i,j))
clall=cl[0]+cl[1]+cl[2]
if(clall==0):
#黑色
fh.write("1")
else:
fh.write("0")
fh.write("\n")
fh.close()'''
#加载数据#将数据转为数组
defdatatoarray(fname):
arr=[]
fh=open(fname)#图片是32*32的横轴每次读取32
for i in range(0,32):
thisline=fh.readline()#读每一行
for j in range(0,32):#读入到数组里
arr.append(int(thisline[j]))returnarr
arr1=datatoarray("~/coding/python/data/testandtraindata/testdata/0_74.txt")#print(arr1)#建立一个函数,取文件的前缀
defseplabel(fname):
filestr=fname.split(".")[0]
label=int(filestr.split("_")[0])returnlabel#建立训练数据
deftraindata():#存储类别
labels=[]#得到训练目录下所有的文件
trainfile=listdir("~/coding/python/data/testandtraindata/traindata")#取当前文件有多少个
num=len(trainfile)#生成一个多少行多少列的向量,行的长度应该是32*32=1024(列),每一行存储一个文件
#用一个数组存储所有训练数据,行:文件总数,列:1024
trainarr=zeros((num,1024))#第一层循环文件
for i inrange(0,num):
thisfname=trainfile[i]#调用seplabel函数
thislabel=seplabel(thisfname)#存到数组里
labels.append(thislabel)#调用datatoarray函数,i,:处理重复读取
trainarr[i,:]=datatoarray("~/coding/python/data/testandtraindata/traindata/"+thisfname)returntrainarr,labels#用测试数据条用KNN算法去测试,看是否能够准确识别
defdatatest():
trainarr,labels=traindata()
testlist=listdir("~/coding/python/data/testandtraindata/testdata")
tnum=len(testlist)for i inrange(0,tnum):
thistestfile=testlist[i]
testarr=datatoarray("~/coding/python/data/testandtraindata/testdata/"+thistestfile)
rknn=knn(3,testarr,trainarr,labels)print(rknn)#a=datatest()#print(a)#抽某一个文件测试文件出来进行验证
trainarr,labels=traindata()
thistestfile="8_15.txt"testarr=datatoarray("~/coding/python/data/testandtraindata/testdata/"+thistestfile)
rknn=knn(3,testarr,trainarr,labels)print(rknn)