数据集是自行下载的 MNIST 手写数字识别数据,包含训练集 train.csv、测试集 test.csv,以及 submission.csv(存放的是 test.csv 对应的标签)。KNN 的原理很简单,这里不再赘述,直接上代码。
#author:zhouchao
#date:2017-12-07 11:13
#description:use knn to recognize num
import numpy as np
from numpy import *
import operator
from numpy import random
def load_train_data(path):
    """Load a comma-separated training file and split it into features and labels.

    Every row is parsed as numbers by ``np.loadtxt``; column 0 is taken as
    the label and the remaining columns as the feature vector.

    Args:
        path: Path of the comma-separated file to read (no rows are skipped).

    Returns:
        A ``(vec, labels)`` tuple where ``vec`` is a 2-D array holding every
        column except the first, and ``labels`` is a list of one-element
        lists (``[[label], ...]``) built from the first column.
    """
    # ndmin=2 keeps a one-row file two-dimensional; without it loadtxt
    # returns a 1-D array and the column slicing below raises IndexError.
    train = np.loadtxt(path, delimiter=",", skiprows=0, ndmin=2)
    vec = train[:, 1:]
    # Slice with 0:1 (not 0) so each label stays wrapped in its own list.
    labels = train[:, 0:1].tolist()
    # Call form of print emits the same text on Python 2 and Python 3.
    print(type(labels))
    return vec, labels
def predict(line, vec, labels, k=20):
    """Classify one sample by majority vote among its nearest neighbours.

    The Euclidean distance from ``line`` to every row of ``vec`` is
    computed, the ``k`` closest rows are selected, and the label that
    occurs most often among them is returned.

    Args:
        line: Feature vector of the sample to classify; it must be
            broadcastable against the rows of ``vec``.
        vec: 2-D array of training feature vectors, one row per sample.
        labels: Sequence indexed like the rows of ``vec`` whose items are
            one-element sequences, i.e. ``labels[i][0]`` is the label of row i.
        k: Number of neighbours that vote. Defaults to 20 and is capped at
            the number of training samples.

    Returns:
        The most frequent label among the selected neighbours. On a tie the
        first such label in the vote dictionary's iteration order wins.

    Raises:
        ValueError: If ``vec`` has no rows or ``k`` is smaller than 1.
    """
    numSamples = vec.shape[0]
    if numSamples == 0:
        raise ValueError("predict() needs at least one training sample")
    if k < 1:
        raise ValueError("k must be a positive integer")
    # Never ask for more neighbours than exist; indexing past the end of
    # the sorted index array would raise IndexError on small training sets.
    k = min(k, numSamples)

    # Broadcasting subtracts ``line`` from every training row at once.
    diff = np.asarray(line) - vec
    distance = np.sqrt(np.sum(diff ** 2, axis=1))
    nearest = np.argsort(distance)[:k]

    classCount = {}
    for idx in nearest:
        voteLabel = labels[idx][0]
        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1

    # max() keeps the first key that reaches the highest count.
    return max(classCount, key=classCount.get)
if __name__ == "__main__":
    # Script entry point: load the training set, then classify every
    # comma-separated row of the test file and print one prediction per row.
    vec, labels = load_train_data("../../data/handwrite/train.csv")
    # The with-block closes the file even if a row fails to parse, and
    # iterating the handle streams rows instead of loading them all at once.
    with open("../../data/handwrite/test.txt") as f:
        for line in f:
            # Skip blank lines (e.g. a trailing newline at end of file),
            # which would otherwise make int('') raise ValueError.
            if not line.strip():
                continue
            nums = [int(x) for x in line.split(",")]
            matrix = np.array(nums)
            print(predict(matrix, vec, labels))