# 李航《统计学习方法》K近邻代码 — k-nearest-neighbour (KNN) implementation following Li Hang's "Statistical Learning Methods"
import numpy as np
import pandas as pd
import time
def loadData(filename):
    """Load a comma-separated data file of labelled integer samples.

    Each non-blank line is expected to have the form
    ``label,v1,v2,...`` where every field is an integer.  The first field
    is taken as the label; the remaining fields are divided by 255.

    Args:
        filename: Path of the CSV file to read.

    Returns:
        A ``(dataArr, labelArr)`` tuple: ``dataArr`` is a list of
        per-sample lists of floats, ``labelArr`` is the list of integer
        labels in the same order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a non-blank line contains a non-integer field.
    """
    print('start to load data')
    dataArr = []
    labelArr = []
    # The context manager guarantees the handle is closed even when a line
    # fails to parse; iterating the file object streams it line by line
    # instead of materialising every line first.
    with open(filename, 'r') as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:
                # Blank (e.g. trailing) lines carry no sample; skip them
                # rather than failing on int('').
                continue
            curLine = stripped.split(',')
            dataArr.append([int(num) / 255 for num in curLine[1:]])
            labelArr.append(int(curLine[0]))
    return dataArr, labelArr
def calcDist(x1, x2):
    """Return the Euclidean (L2) distance between ``x1`` and ``x2``.

    The element-wise difference is squared, all squares are summed into a
    single scalar, and the square root of that sum is returned.  The
    operands must support ``x1 - x2`` (e.g. NumPy arrays of equal shape).
    """
    delta = x1 - x2
    squared = np.square(delta)
    return np.sqrt(squared.sum())
def getClosest(trainDataMat, trainLabelMat, x, topK):
    """Predict the label of ``x`` by majority vote among its nearest neighbours.

    Euclidean distances from ``x`` to every training row are computed in a
    single vectorized step, the ``topK`` closest rows are selected, and the
    most frequent label among them is returned.  When several labels share
    the highest vote count, the smallest label wins.

    Args:
        trainDataMat: 2-D array-like, one training sample per row.
        trainLabelMat: Array-like holding one label per training row; it is
            flattened, so a 1-D sequence or a column vector both work.
        x: Array-like query sample with the same number of features as a
            training row (1-D or a single row).
        topK: Number of neighbours that vote.  Values larger than the
            training-set size use every training sample.

    Returns:
        The winning label as a Python ``int``.

    Raises:
        ValueError: If the training set is empty or not 2-D, the label count
            does not match the number of rows, the query has a different
            number of features, or ``topK`` is smaller than 1.
    """
    train = np.asarray(trainDataMat, dtype=float)
    labels = np.asarray(trainLabelMat).ravel()
    query = np.asarray(x, dtype=float).reshape(1, -1)

    if train.ndim != 2 or train.shape[0] == 0:
        raise ValueError('training data must be a non-empty 2-D array')
    if labels.shape[0] != train.shape[0]:
        raise ValueError('number of labels must match number of training rows')
    if query.shape[1] != train.shape[1]:
        raise ValueError('query and training samples must have the same number of features')
    if topK < 1:
        raise ValueError('topK must be at least 1')

    # Row-wise sum of squared differences; einsum avoids allocating a second
    # full-size temporary for the squared values.
    diff = train - query
    distances = np.sqrt(np.einsum('ij,ij->i', diff, diff))

    # A stable sort keeps equal distances in training-set order, so the
    # selected neighbours are deterministic.
    nearest = np.argsort(distances, kind='stable')[:topK]

    # np.unique returns the labels sorted ascending, and argmax picks the
    # first maximum, so ties resolve to the smallest label.  Counting only the
    # labels that actually occur removes any limit on the label range.
    values, counts = np.unique(labels[nearest], return_counts=True)
    return int(values[np.argmax(counts)])
def model_test(trainDataArr, trainLabelArr, testDataArr, testLabelArr, topK, testNum=200):
    """Measure k-nearest-neighbour accuracy on the leading test samples.

    Each of the first ``testNum`` test samples is classified with
    ``getClosest`` against the full training set and compared with its
    true label.

    Args:
        trainDataArr: Training samples, one sequence of features per sample.
        trainLabelArr: Training labels, one per training sample.
        testDataArr: Test samples in the same format as ``trainDataArr``.
        testLabelArr: True labels of the test samples.
        topK: Number of neighbours used for each vote.
        testNum: Maximum number of test samples to evaluate (default 200,
            the value previously hard-coded).  It is clamped to the number
            of available test samples.

    Returns:
        The fraction of evaluated samples that were classified correctly.

    Raises:
        ValueError: If there is no test sample to evaluate.
    """
    print('start to test')
    # Plain ndarrays are used instead of np.mat, which no longer exists in
    # NumPy 2.0.  Labels are flattened so that indexing yields a scalar.
    trainDataMat = np.asarray(trainDataArr)
    trainLabelMat = np.asarray(trainLabelArr).ravel()
    testDataMat = np.asarray(testDataArr)
    testLabelMat = np.asarray(testLabelArr).ravel()

    # Never index past the end of a test set smaller than testNum.
    sampleCount = min(testNum, len(testDataMat), len(testLabelMat))
    if sampleCount <= 0:
        raise ValueError('no test samples to evaluate')

    errorCnt = 0
    for i in range(sampleCount):
        print('test %d:%d' % (i + 1, sampleCount))
        predicted = getClosest(trainDataMat, trainLabelMat, testDataMat[i], topK)
        if predicted != int(testLabelMat[i]):
            errorCnt += 1
    return 1 - (errorCnt / sampleCount)
if __name__ == "__main__":
    # Script entry point: load both MNIST CSV files, evaluate the classifier
    # with 25 neighbours, then report the accuracy and the wall-clock time of
    # the whole run (loading included).
    startedAt = time.time()

    trainSet = loadData('./mnist_train.csv')
    testSet = loadData('./mnist_test.csv')
    accuracy = model_test(trainSet[0], trainSet[1], testSet[0], testSet[1], 25)

    finishedAt = time.time()
    elapsed = finishedAt - startedAt

    print('the accuracy rate is:', accuracy)
    print('the run time is:', elapsed)