K-近邻算法采用测量不同特征值之间的距离的方法进行分类
- 优点:精度高,对异常值不敏感,无数据输入假定
- 缺点:计算复杂度高,空间复杂度高
- 适用范围:数值型和标称型
算法执行描述:
对未知类别属性的数据集中的每个点执行以下操作
1.计算已知类别数据集中的点与当前点之间的距离
2.按距离递增次序排序
3.选取与当前点距离最小的K个点
4.统计前K个点所属类别的出现频率,返回出现频率最高的类别作为当前点的预测分类
from numpy import *
import operator
import matplotlib
import matplotlib.pyplot as plt
def createDataSet():
    """Build the four-point toy training set used by the demos.

    Returns:
        A ``(group, labels)`` pair: a 4x2 array of sample points and the
        list of their class labels, in the same order.
    """
    # Each entry pairs a 2-D point with its class label.
    labelled_points = (
        ([1.0, 1.1], 'A'),
        ([1.0, 1.0], 'A'),
        ([0, 0], 'B'),
        ([0, 0.1], 'B'),
    )
    group = array([point for point, _ in labelled_points])
    labels = [tag for _, tag in labelled_points]
    return group, labels
def classify0(inX, dataSet, labels, k):
    """Classify one sample by majority vote among its k nearest neighbours.

    Distances are Euclidean, computed between ``inX`` and every row of
    ``dataSet``.

    Args:
        inX: Feature vector to classify (sequence or 1-D array).
        dataSet: Training samples as a 2-D array, one row per sample.
        labels: Class labels; ``labels[i]`` belongs to row ``i`` of dataSet.
        k: Number of nearest neighbours that vote. A value larger than the
            number of training samples is clamped to that number.

    Returns:
        The label with the most votes among the k nearest samples. On a tie
        the tied label that received its first vote from the closer sample
        wins.

    Raises:
        ValueError: If the training set is empty or k is less than 1.
    """
    dataSetSize = dataSet.shape[0]  # number of training samples (rows)
    if dataSetSize == 0:
        raise ValueError("dataSet must contain at least one sample")
    if k < 1:
        raise ValueError("k must be at least 1")
    if k > dataSetSize:
        # There cannot be more voters than training samples.
        k = dataSetSize

    # Broadcasting subtracts inX from every row without building a tiled copy.
    diffMat = inX - dataSet
    sqDistances = (diffMat ** 2).sum(axis=1)  # sum of squared differences per row
    distances = sqDistances ** 0.5
    sortedDistIndicies = distances.argsort()  # row indices, nearest first

    # Count the labels of the k nearest rows.
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

    # Stable sort by vote count, highest first; ties keep insertion order.
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
def file2matrix(filename):
    """Parse a tab-separated data file into a feature matrix and label list.

    Each non-blank line supplies one sample: its first three fields are the
    numeric features and its last field is the integer class label. Blank
    lines are skipped.

    Args:
        filename: Path of the text file to read.

    Returns:
        A ``(returnMat, classLabelVector)`` pair: an N x 3 float array of
        features and a list of N integer labels, in file order.

    Raises:
        ValueError: If a feature is not a number or a label is not an integer.
    """
    # The context manager closes the file even if parsing fails.
    with open(filename) as fr:
        records = [line.strip().split('\t') for line in fr if line.strip()]

    returnMat = zeros((len(records), 3))
    classLabelVector = []
    for index, listFromLine in enumerate(records):
        returnMat[index, :] = [float(value) for value in listFromLine[0:3]]
        classLabelVector.append(int(listFromLine[-1]))  # last field is the label
    return returnMat, classLabelVector
def autoNorm(dataSet):
    """Min-max normalise every feature column of dataSet.

    Each column is shifted by its minimum and divided by its range
    (max - min). A constant column, whose range is 0, is divided by 1
    instead, so it normalises to all zeros rather than NaN.

    Args:
        dataSet: 2-D array with one sample per row and one feature per column.

    Returns:
        A ``(normDataSet, ranges, minVals)`` tuple: the normalised array,
        the per-column ranges (a constant column keeps its true range of 0)
        and the per-column minima.
    """
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    # Avoid 0/0 on constant columns; NaN there would poison every distance.
    divisors = where(ranges == 0, 1, ranges)
    # Broadcasting applies the per-column values to every row.
    normDataSet = (dataSet - minVals) / divisors
    return normDataSet, ranges, minVals
def test():
    """Classify the origin against the toy data set with k=3 and print the result."""
    samples, tags = createDataSet()
    prediction = classify0([0, 0], samples, tags, 3)
    print(prediction)
def test2():
    """Load the dating data, normalise it and print the first 20 normalised rows."""
    datingDataMat, datingLabels = file2matrix("datingTestSet2.txt")
    # Only the normalised matrix is needed here; ranges and minima are unused.
    normalised, _ranges, _minVals = autoNorm(datingDataMat)
    print(normalised[:20])
    # Optional scatter plot of features 1 and 2, sized and coloured by label:
    # fig = plt.figure()
    # ax = fig.add_subplot(111)
    # ax.scatter(datingDataMat[:,1],datingDataMat[:,2],15.0*array(datingLabels),15.0*array(datingLabels))
    # plt.show()
def datingClassTest(hoRatio=0.1, filename='datingTestSet2.txt', k=3):
    """Measure the classifier's error rate on the dating data with a hold-out split.

    The leading ``hoRatio`` share of the rows is used as test samples and
    the remaining rows as the training set. Every prediction is printed
    next to the true label, followed by the overall error rate and the
    number of errors.

    Args:
        hoRatio: Fraction of the rows held out for testing.
        filename: Tab-separated data file loaded with file2matrix.
        k: Number of neighbours passed to classify0.
    """
    datingDataMat, datingLabels = file2matrix(filename)
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m * hoRatio)  # rows [0, numTestVecs) are the test samples

    # The training part is the same for every test sample, so slice it once.
    trainingMat = normMat[numTestVecs:m, :]
    trainingLabels = datingLabels[numTestVecs:m]

    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i, :], trainingMat, trainingLabels, k)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
        if classifierResult != datingLabels[i]:
            errorCount += 1.0

    # With no test samples there is nothing to divide by; report a rate of 0.
    errorRate = errorCount / float(numTestVecs) if numTestVecs else 0.0
    print("the total error rate is: %f" % errorRate)
    print(errorCount)
if __name__ == "__main__":
    # Script entry point: run the hold-out evaluation on the dating data.
    # Alternative demos: test() for the toy data set, test2() for normalisation.
    datingClassTest()