import csv
import random
import math
import operator
def loadDataset(filename, split, trainingSet=None, testSet=None):
    """Load a CSV dataset and randomly partition its rows into train/test sets.

    Each non-empty row is expected to hold numeric feature columns followed by
    the class label in the last column. Feature columns are converted to
    ``float`` in place; the label column is left as read.

    Args:
        filename: Path to the CSV file.
        split: Probability that a row is assigned to the training set; rows
            not assigned there go to the test set.
        trainingSet: Optional list that training rows are appended to. A new
            list is created when omitted.
        testSet: Optional list that test rows are appended to. A new list is
            created when omitted.

    Returns:
        A ``(trainingSet, testSet)`` tuple. These are the same list objects
        that were passed in, if any, so callers relying on in-place mutation
        keep working.

    Raises:
        ValueError: If a feature column cannot be parsed as a float.
    """
    # Create fresh lists per call; ``=[]`` defaults would be shared between calls.
    if trainingSet is None:
        trainingSet = []
    if testSet is None:
        testSet = []
    # newline='' is required by the csv module for correct newline handling.
    with open(filename, 'rt', newline='') as csvfile:
        for row in csv.reader(csvfile):
            # Skip blank lines, such as a trailing newline at EOF. The old code
            # always dropped the final row instead, which lost a real sample
            # whenever the file did not end with an empty line.
            if not row or all(not field.strip() for field in row):
                continue
            # Every column except the trailing class label is a numeric feature.
            for y in range(len(row) - 1):
                row[y] = float(row[y])
            if random.random() < split:
                trainingSet.append(row)
            else:
                testSet.append(row)
    return trainingSet, testSet
def euclideanDistance(instance1, instance2, length):
    """Return the Euclidean distance between two instances.

    Only the first ``length`` coordinates of each instance are compared. This
    lets a caller exclude a trailing class label by passing
    ``len(instance) - 1``.
    """
    total = 0
    for idx in range(length):
        delta = instance1[idx] - instance2[idx]
        total += delta ** 2
    return math.sqrt(total)
def getNeighbors(trainingSet, testInstance, k):
    """Return the ``k`` training rows closest to ``testInstance``.

    Distance is Euclidean over every column except the last, which is treated
    as the class label.

    Args:
        trainingSet: Rows of numeric features followed by a label.
        testInstance: A single row with the same layout.
        k: Number of neighbours requested.

    Returns:
        Up to ``k`` training rows, ordered from nearest to farthest. Ties keep
        their training-set order because the sort is stable. Fewer than ``k``
        rows are returned when the training set is smaller than ``k``, and an
        empty list when ``k <= 0``.
    """
    length = len(testInstance) - 1  # number of features; the last column is the label
    distances = [
        (row, euclideanDistance(testInstance, row, length)) for row in trainingSet
    ]
    distances.sort(key=operator.itemgetter(1))
    # Slicing avoids the IndexError the old index loop raised when k exceeded
    # the training set. max() keeps a negative k returning [] rather than
    # slicing off the tail of the list.
    return [row for row, _ in distances[:max(k, 0)]]
def getResponse(neighbors):
    """Return the class label that occurs most often among ``neighbors``.

    The label is read from the last column of each neighbour row. When several
    labels tie for the highest count, the one seen first in ``neighbors`` wins,
    because the descending sort below is stable. Raises ``IndexError`` if
    ``neighbors`` is empty.
    """
    tally = {}
    for neighbor in neighbors:
        label = neighbor[-1]
        tally[label] = tally.get(label, 0) + 1
    ranked = sorted(tally.items(), key=lambda pair: pair[1], reverse=True)
    best_label, _ = ranked[0]
    return best_label
def getAccuracy(testSet, predictions):
    """Return classification accuracy as a percentage in [0, 100].

    Args:
        testSet: Rows whose last column is the true class label.
        predictions: Predicted labels, assumed to be aligned index-by-index
            with ``testSet``. A test row with no matching prediction counts
            as incorrect.

    Returns:
        ``100 * correct / len(testSet)`` as a float, or ``0.0`` when
        ``testSet`` is empty. An empty test set previously raised
        ``ZeroDivisionError``.
    """
    if not testSet:
        return 0.0
    correct = sum(
        1 for row, predicted in zip(testSet, predictions) if row[-1] == predicted
    )
    return correct / float(len(testSet)) * 100.0
def main(filename=r'/home/htt/htt/machineL/Knn/irisdata.txt', split=0.67, k=3):
    """Train and evaluate a k-NN classifier on a CSV dataset, printing the results.

    Args:
        filename: Path to the CSV dataset. Defaults to the original
            hard-coded iris data path.
        split: Probability that each row goes to the training set.
        k: Number of nearest neighbours that vote on each prediction.
    """
    trainingSet = []
    testSet = []
    loadDataset(filename, split, trainingSet, testSet)
    print('train set:' + repr(len(trainingSet)))
    print('test set:' + repr(len(testSet)))
    # Without training rows every prediction would fail with an opaque IndexError.
    if not trainingSet:
        print('No training data; nothing to classify.')
        return
    predictions = []
    for instance in testSet:
        neighbors = getNeighbors(trainingSet, instance, k)
        result = getResponse(neighbors)
        predictions.append(result)
        # Separator added: the old output ran the two values together.
        print('>predicted=' + repr(result) + ', actual:' + repr(instance[-1]))
    accuracy = getAccuracy(testSet, predictions)
    print('Accuracy:' + repr(accuracy) + '%')


# Run only as a script, so importing this module does not start training.
if __name__ == '__main__':
    main()