# Python 3.6
# 此代码取自书中。(This code is taken from the book.)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@file:KNN.py
@time:2019/1/1 14:10
@author:Victor
@site:https://blog.csdn.net/sumup
"""
from numpy import *
import operator
import matplotlib
import matplotlib.pyplot as plt
# def createDataSet():
# group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
# labels = ['A','A','B','B']
# return group, labels
def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its ``k`` nearest neighbours.

    Distance is Euclidean, measured between ``inX`` and every row of
    ``dataSet``.

    Args:
        inX: feature vector to classify; its length must equal the number
            of columns of ``dataSet``.
        dataSet: 2-D array of training samples, one sample per row.
        labels: sequence of class labels; ``labels[i]`` belongs to row ``i``.
        k: number of nearest neighbours that vote. A value larger than the
            number of samples is clamped to the number of samples.

    Returns:
        The label with the most votes. When vote counts tie, the label that
        received a vote first (i.e. from the nearer neighbour) wins, because
        votes are counted in distance order and the sort below is stable.

    Raises:
        ValueError: if ``k`` < 1 or ``dataSet`` has no rows.
    """
    dataSetSize = dataSet.shape[0]
    if k < 1:
        raise ValueError('k must be at least 1, got %r' % (k,))
    if dataSetSize == 0:
        raise ValueError('dataSet must contain at least one sample')
    # Requesting more neighbours than samples would index past the end of
    # the sorted indices below; let every sample vote instead.
    if k > dataSetSize:
        k = dataSetSize

    # Broadcasting subtracts inX from every row (same result as tiling it).
    diffMat = asarray(inX) - dataSet
    # Row-wise Euclidean distance: sqrt of the sum of squared differences.
    distances = (diffMat ** 2).sum(axis=1) ** 0.5
    # Indices of the samples ordered from nearest to farthest.
    sortedDistIndicies = distances.argsort()

    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # Highest vote count first; sorted() is stable, so ties keep the order
    # in which labels first received a vote.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
def file2matrix(filename):
    """Load a tab-separated data file into a feature matrix and a label list.

    On each non-blank line the first three tab-separated fields are numeric
    features and the last field is an integer class label.

    Args:
        filename: path of the text file to read.

    Returns:
        ``(returnMat, classLabelVector)``: ``returnMat`` is an ``(n, 3)``
        float array with one row per non-blank line, and
        ``classLabelVector`` is the list of the ``n`` integer labels.

    Raises:
        ValueError: if a line has fewer than three feature fields or a
            field is not numeric.
    """
    # The context manager guarantees the file is closed. strip() removes
    # surrounding whitespace ('\n', '\r', '\t', ' '), and blank lines are
    # dropped so e.g. a trailing empty line does not break parsing.
    with open(filename) as fr:
        lines = [line.strip() for line in fr]
    lines = [line for line in lines if line]

    returnMat = zeros((len(lines), 3))
    classLabelVector = []
    for index, line in enumerate(lines):
        listFromLine = line.split('\t')
        returnMat[index, :] = [float(value) for value in listFromLine[0:3]]
        classLabelVector.append(int(listFromLine[-1]))
    return returnMat, classLabelVector
def autoNorm(dataSet):
    """Min-max scale every column of ``dataSet`` into the range [0, 1].

    Each value is mapped with ``newValue = (oldValue - min) / (max - min)``,
    where ``min`` and ``max`` are taken per column.

    Args:
        dataSet: 2-D numeric array, one sample per row.

    Returns:
        ``(normDataSet, ranges, minVals)``: the scaled array, the per-column
        ``max - min``, and the per-column minimum. Callers can use
        ``ranges`` and ``minVals`` to scale new samples the same way.
        A column whose values are all equal (range 0) is mapped to 0 rather
        than NaN; its entry in ``ranges`` is still 0.
    """
    # Per-column minimum and maximum (axis 0 = down the rows).
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    # Divide constant columns by 1 instead of 0 so they become 0, not NaN.
    safeRanges = where(ranges == 0, 1, ranges)
    # Broadcasting applies the per-column values to every row.
    normDataSet = (dataSet - minVals) / safeRanges
    return normDataSet, ranges, minVals
def datingClassTest(filename='datingTestSet2.txt', hoRatio=0.10, k=3):
    """Measure the kNN classifier's error rate on the dating data set.

    The data is min-max normalised with ``autoNorm``. The first
    ``int(m * hoRatio)`` rows are classified against the remaining rows,
    and each prediction is printed next to the true label.

    Args:
        filename: tab-separated data file read with ``file2matrix``.
        hoRatio: fraction of rows held out as the test set (default 10%).
        k: number of neighbours passed to ``classify0``.

    Returns:
        The error rate (misclassified / tested), which is also printed.

    Raises:
        ValueError: if the held-out test set would be empty, or would leave
            no rows for training.
    """
    datingDataMat, datingLabels = file2matrix(filename)
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m * hoRatio)
    if numTestVecs == 0 or numTestVecs >= m:
        raise ValueError('hoRatio=%r leaves %d test rows out of %d; need at '
                         'least one test row and one training row'
                         % (hoRatio, numTestVecs, m))

    # The training split is the same for every test row; slice it once.
    trainMat = normMat[numTestVecs:m, :]
    trainLabels = datingLabels[numTestVecs:m]

    errorCount = 0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i, :], trainMat, trainLabels, k)
        print('the classifier came back with: %d, the real answer is: %d'
              % (classifierResult, datingLabels[i]))
        if classifierResult != datingLabels[i]:
            errorCount += 1
    errorRate = errorCount / float(numTestVecs)
    print('the total error rate is: %f' % errorRate)
    return errorRate
def classifyPerson(filename='datingTestSet2.txt', k=3):
    """Ask for one person's features on stdin and predict how much they are liked.

    The three answers are scaled with the same min/max as the training data
    from ``filename``, then classified with ``classify0``.

    Args:
        filename: tab-separated training data read with ``file2matrix``.
        k: number of neighbours passed to ``classify0``.

    Returns:
        The predicted answer text (also printed).

    Raises:
        ValueError: if an answer cannot be parsed as a number.
    """
    resultList = ['not at all', 'in small doses', 'in large doses']
    percentTats = float(input("percentage of time spent playing video games? "))
    ffMiles = float(input("frequent flier miles earned per year? "))
    iceCream = float(input("liters of ice cream consumed per year? "))
    datingDataMat, datingLabels = file2matrix(filename)
    normMat, ranges, minVals = autoNorm(datingDataMat)
    # Feature order must match the data file's columns as this code assumes
    # them: miles, game-time percentage, ice cream.
    inArr = array([ffMiles, percentTats, iceCream])
    classifierResult = classify0((inArr - minVals) / ranges, normMat,
                                 datingLabels, k)
    # NOTE(review): assumes labels are 1, 2, 3 (mapped to resultList 0..2);
    # a label of 0 would silently wrap to the last entry.
    answer = resultList[classifierResult - 1]
    print("you will probably like this person: ", answer)
    return answer
if __name__ == '__main__':
    # Only run the interactive demo when executed as a script, so that
    # importing this module does not block waiting for keyboard input.
    #
    # Toy-data demo (needs createDataSet to be uncommented first):
    # group, labels = createDataSet()
    # print(group)
    # print(labels)
    # print(classify0([0, 0], group, labels, 3))
    #
    # Scatter plot of the dating data (column 0 vs column 1), with marker
    # size and colour both scaled by the class label:
    # datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
    # fig = plt.figure()
    # add_subplot(mnp) adds an axes to the figure. subplot(m, n, p), or
    # subplot(mnp), is the most common form: it lays several plots out on
    # one figure in a grid of m rows and n columns (n plots side by side in
    # each of the m rows; a leading 2 means two rows of plots). p selects
    # which plot to draw on; 1 is the first position from the left.
    # ax = fig.add_subplot(111)
    # ax.scatter(datingDataMat[:, 0], datingDataMat[:, 1],
    #            15.0 * array(datingLabels), 15.0 * array(datingLabels))
    # print(autoNorm(datingDataMat))
    # plt.show()
    #
    # Hold-out evaluation of the classifier:
    # datingClassTest()
    classifyPerson()