1、一般流程
1)收集数据:可以使用任何方法
2)准备数据:距离计算所需要的数值,最好是结构化的数据格式
3)分析数据:可以使用任何方法
4)训练算法: 此步骤不适用于k-近邻
5)测试算法:计算错误率
6)使用算法:首先需要输入样本数据和结构化的输出结果,然后运用k-近邻算法
判定输入数据分别属于哪个分类,最后应用对计算出的分类执行后续的处理
2、kNN的优缺点
- 优点:精度高、对异常值不敏感、无数据输入假定
- 缺点:计算复杂度高、空间复杂度高
- 适用范围:数值型和标称型
3、k-近邻算法实现
# python 3.6
# 20180418
# **所有需要加载的包:**
from numpy import *
import operator
import tqdm #查看循环进度
import os # 为了批量打开
# 为了测试用
def CreateDataSet():
group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
labels = ['A','A','B','B']
return group, labels
group, labels = CreateDataSet()
# k-近邻算法
def classify0(inX, dataSet, labels, k):
"""
计算inx 为什么类型
:param inX: 输入向量(需要确定分类的值)
:param dataSet: 训练集样本 array
:param labels: 标签向量 List
:param k: 选择最近邻居的数目
:return: 标签,即种类
"""
# 计算距离
dataSetSize = dataSet.shape[0]
diffMat = tile(inX, (dataSetSize,1)) - dataSet
sqDiffMatt = diffMat**2
sqDistances = sqDiffMatt.sum(axis=1)
distance = sqDistances ** 0.5
# 元素从小到大排列,提取其对应的index(索引)
sortDistIndicies = distance.argsort()
# 选择距离最小的k个点
classCount = {}
for i in range(k):
voteIlabel = labels[sortDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
sortedClassCount = sorted(classCount.items(),key = \
operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]
classify0([0,0],group, labels ,1)
结果:
Out[18]: 'B'
4、在约会网站上使用K-近邻算法
1)收集数据:提供文本文件
2)准备数据:使用Python解析文本文件
3)分析数据:使用Matplotlib画二维扩散图
4)训练算法:此步骤不适用于K-近邻
5)测试算法:使用海伦提供的部分数据作为测试样本测试样本和非测试样本的区别在于:测试样本是已经完成分类的数据,如果预测分类与实际不同,标记为一个错误
6)使用算法:产生简单的命令行程序,然后海伦可以输入一些特征数据以判断对方是否为自己喜欢类型
使用Python解析文本文件:
def file2matrix(filename):
"""
将文本记录转换为Numpy的解析程序
:param filename: 文件路径
:return: 训练集样本 array; 标签向量 List
"""
fr = open(filename)
arrayOLines = fr.readlines()
numberOfLines = len(arrayOLines)
returnMat = zeros((numberOfLines,3))
classLabelVector = []
index = 0
for line in tqdm.tqdm(arrayOLines):
line = line.strip()
line = line.split("\t")
returnMat[index,:] = line[0:3]
classLabelVector.append(int(line[-1]))
index += 1
fr.close()
return returnMat, classLabelVector
# 测试
fil = r"E:\python Data\datingTestSet2.txt"
datingDataMAT,likeornot = file2matrix(fil)
消除量纲影响:
def autoNorm(dataset):
"""
将每个字段数据归一化
:param dataset: 训练集样本 array
:return: 归一化后的数据集 array;
"""
minv = dataset.min(0)
maxv = dataset.max(0)
ranges = maxv - minv
m = dataset.shape[0]
normDateset = dataset - tile(minv,(m,1))
normDateset = normDateset/tile(ranges,(m,1)) # 特征值相除
return normDateset, ranges, minv
normDateset, ranges, minv =autoNorm(datingDataMAT)
分类器针对约会网站的测试代码:
def datingClassTest():
hoRating = 0.10
datingDataMAT, likeornot = file2matrix(fil)
normDateset, ranges, minv = autoNorm(datingDataMAT)
m = normDateset.shape[0]
numTestVecs = int(m*hoRating)
errorCount = 0
for i in tqdm.tqdm(range(numTestVecs)):
# 测试前100, 后900作为训练集
r = classify0(normDateset[i,:],
normDateset[numTestVecs:m,:],
likeornot[numTestVecs:m],3)
if (r != likeornot[i]):
errorCount += 1
print("\nthe total error rate is: {}".format(errorCount/float(numTestVecs)))
datingClassTest()
5、手写识别系统
1)收集数据:提供文本文件
2)准备数据:编写函数,将图像格式转换为分类器使用的向量格式
3)分析数据:在Python命令提示中检查数据,确保它符合要求
4)训练算法:此步骤不适用于K-近邻
5)训练算法:编写函数使用提供的部分数据集作为测试样本,测试样本与非测试样本的区别在于测试样本是已经完成分类的数据,如果预测分类与实际类别不同,则标记为一个错误
6)使用算法:本例未完成该步骤
将图像格式转换为分类器使用的向量格式:
def img2vector(filname):
"""
循环读出文件的前32行,并将每行的32个字符值存储在Numpy数组中
:param filname:
:return:
"""
returnVect = zeros((1,1024))
fr =open(filname)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0,32*i + j] = int(lineStr[j])
fr.close()
return returnVect
实现识别及计算错误率:
rootdir = "E:/python Data/Ch02/trainingDigits/"
rootdirt = "E:/python Data/Ch02/testDigits/"
def handwritingClassTest():
"""
遍历训练文件生成测试矩阵
遍历测试文件进行预测
计算错误率
:return: 预测错误率
"""
Labels = []
trainingMAT = 0
for root,dirs,files in os.walk(rootdir):
m = len(files)
trainingMAT = zeros((m, 1024))
n = 0
for i in tqdm.tqdm(files):
lab = i.split("_")[0] # 文件的首个数字为该文件的具体数字
trainingMAT[n,:] = img2vector(rootdir + i)
Labels.append(lab)
n += 1
errorCounts = 0.0
l = 0
for roott,dirst,filest in os.walk(rootdirt):
l = len(filest)
for it in tqdm.tqdm(filest):
lab1 = it.split("_")[0]
rVect = img2vector(rootdirt + it)
r = classify0(rVect, trainingMAT, Labels, 3)
if (r != lab1):
errorCounts += 1
print("\nthe total error rate is: {}".format(errorCounts/l))
handwritingClassTest() # 运行时间 1 + 28s
实际上使用这个算法时,算法的执行效率并不高。k决策树就是k-近邻算法的优化版,可以大大节省计算的开销。
参考:《机器学习实战》
数据:链接:https://pan.baidu.com/s/1dkb0W0xBR3csj1UhPmxq7w 密码:najh