机器学习实战第二章-k近邻算法(包含一些python绘图基础)

一,k近邻算法概述

k近邻算法是一种简单有效但并不高效的非线性分类方法。

  • 优点:精度高,对异常值不敏感、无数据输入假设。
  • 缺点:计算复杂度高、空间复杂度高。
  • 使用数据范围:离散型和连续型。

二,k近邻算法的核心步骤

对未知类别属性的数据集中的每一个点依次执行以下操作:
1. 计算已知数据集中的点与当前点之间的距离。
2. 按照距离递增次序排序。
3. 选取与当前点距离最小的k个点。
4. 确定前k个点所在类别的出现频率。
5. 返回前k个点出现频率最高的类别作为当前点的预测分类。

三,k近邻算法应用的一般流程

  1. 收集数据:可以使用任何方法。例如:存储到数据库(mysql、mongodb等)或者直接存储成文本文件。
  2. 准备数据:距离计算所需要的数值,最好是结构化的数据格式。
  3. 分析数据:可以使用任何方法。例如:用matplotlib画二维扩散图。
  4. 训练算法:此步骤不适用于k近邻算法,因为k近邻直接基于实例,无需训练。
  5. 测试算法:在测试集上计算错误率。
  6. 使用算法:首先输入样本数据和结构化的输出结果,然后运行k近邻算法判定输入数据属于哪一类别,最后应用对计算出的分类执行后续的处理。

四,k近邻算法应用的Python3代码实现

from numpy import *
import operator
import matplotlib
import matplotlib.pyplot as plt
# 1.收集数据省略,因为已给出数据。
# 4.训练算法省略,因为k近邻算法无需训练。

# k近邻算法核心步骤
def classify0(inX,dataSet,labels,k):
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX,(dataSetSize,1))-dataSet
    sqDiffMat = diffMat**2
    sqDistances = sum(sqDiffMat,axis=1)
    sortedDistIndicies = sqDistances.argsort()
    classCount = {}
    for i in range(k):
        voteILabel = labels[sortedDistIndicies[i]]
        classCount[voteILabel] = classCount.get(voteILabel,0)+1
    sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]
# 2.准备数据
def file2matrix(filename):
    fr = open(filename)
    fileLines = fr.readlines()
    numsOfLines = len(fileLines)
    returnMat = zeros((numsOfLines,3))
    index = 0
    classLabelVector = []
    for line in fileLines:
        listOfLine = line.strip().split('\t')
        returnMat[index,:] = listOfLine[:3]
        classLabelVector.append(int(listOfLine[-1]))
        index += 1
    return returnMat,classLabelVector
# 3.分析数据
def plotData(DataSet,Labels,i,j):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter(DataSet[:,i],DataSet[:,j],15.0*array(Labels),15.0*array(Labels))
    plt.show()
# 对数据进行数值归一化
def autoNorm(dataSet):
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    normDataSet = zeros(dataSet.shape)
    n = dataSet.shape[0]
    normDataSet = dataSet - tile(minVals,(n,1))
    ranges = maxVals - minVals
    normDataSet = normDataSet/tile(ranges,(n,1))
    return normDataSet,ranges,minVals
# 5.测试算法
def datingClassTest():
    hoRatio = 0.1
    datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')
    normDataSet,ranges,minVals = autoNorm(datingDataMat)
    n = normDataSet.shape[0]
    numTestSet = int(n*hoRatio)
    errorCount = 0
    for i in range(numTestSet):
        classifyResult = classify0(normDataSet[i,:],normDataSet[numTestSet:n,:],datingLabels[numTestSet:n],3)
        print('The classifier result is: %d,and the real answer is:%d。'%(classifyResult,datingLabels[i]))
        if classifyResult!=datingLabels[i]:
            errorCount += 1
    print('The total error rate is: %f'%(errorCount/numTestSet))
# 6.使用算法
def classifyPerson():
    resultList = ['not at all','is small doses','in large doses']
    Game = float(input('please input the percentage of time spent in playing video games:'))
    FlyMiles = float(input('please input the fly miles:'))
    iceCream = float(input('please input the liters of ice cream consumed per year:'))
    inArr = array([FlyMiles,Game,iceCream])
    dateSet,dateLabels = file2matrix('datingTestSet2.txt')
    normData,ranges,minVals = autoNorm(dateSet)
    result = classify0((inArr-minVals)/ranges,normData,dateLabels,3)
    print('You will probably like this person:',resultList[result-1])

附录-Python程序中用到的函数

1, numpy.tile(A,B)函数:,重复A,B次,这里的B可以时int类型也可以是元组类型

>>> import numpy  
>>> numpy.tile([0,0],5)#在列方向上重复[0,0]5次,默认行1次  
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])  
>>> numpy.tile([0,0],(1,1))#在列方向上重复[0,0]1次,行1次  
array([[0, 0]])  
>>> numpy.tile([0,0],(2,1))#在列方向上重复[0,0]1次,行2次  
array([[0, 0],  
       [0, 0]])  
>>> numpy.tile([0,0],(3,1))  
array([[0, 0],  
       [0, 0],  
       [0, 0]])  
>>> numpy.tile([0,0],(1,3))#在列方向上重复[0,0]3次,行1次  
array([[0, 0, 0, 0, 0, 0]])  
>>> numpy.tile([0,0],(2,3))<span style="font-family: Arial, Helvetica, sans-serif;">#在列方向上重复[0,0]3次,行2次</span>  
array([[0, 0, 0, 0, 0, 0],  
       [0, 0, 0, 0, 0, 0]])  

2, operator模块中的常用函数。例如使用 itemgetter() 从元组序列中获取指定的域值。

>>> inventory = [('apple', 3), ('banana', 2), ('pear', 5), ('orange', 1)]
>>> getcount = operator.itemgetter(1)
>>> map(getcount, inventory)
[3, 2, 5, 1]
>>> sorted(inventory, key=getcount)
[('orange', 1), ('banana', 2), ('apple', 3), ('pear', 5)]

3, python2中的iteritems()在python3中变为了items()
在Python2.x中,items( )用于 返回一个字典的拷贝列表【Returns a copy of the list of all items (key/value pairs) in D】,占额外的内存。
iteritems() 用于返回本身字典列表操作后的迭代【Returns an iterator on all items(key/value pairs) in D】,不占用额外的内存。

Python 3.x 里面,iteritems() 和 viewitems() 这两个方法都已经废除了,而 items() 得到的结果是和 2.x 里面 viewitems() 一致的。在3.x 里 用 items()替换iteritems() ,可以用于 for 来循环遍历。

4, matplotlib模块中的subplot()方法。

subplot(numRows, numCols, plotNum)  

subplot将整个绘图区域等分为numRows行* numCols列个子区域,然后按照从左到右,从上到下的顺序对每个子区域进行编号,左上的子区域的编号为1。如果numRows,numCols和plotNum这三个数都小于10的话,可以把它们缩写为一个整数,例如subplot(323)和subplot(3,2,3)是相同的。subplot在plotNum指定的区域中创建一个轴对象。如果新创建的轴和之前创建的轴重叠的话,之前的轴将被删除。

import matplotlib  
import matplotlib.pyplot as plt  

for i,color in enumerate("rgbyck"):  
    plt.subplot(321+i,axisbg=color)    
plt.show()  

效果如下:
效果

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值