机器学习(0)-K-近邻算法(KNN)

优缺点和适用范围

  • 优点:精度高、对异常值不敏感、无数据输入假定。
  • 缺点:计算复杂度高、空间复杂度高。
  • 适用数据范围:数值型和标称型(离散型数据,变量的结果只在有限目标集中取值)。

原理/数学推理过程

  • 存在数据集,且每个数据存在标签,输入没有标签的数据后,计算该数据到所有其他已知类别数据的距离,排序,并取最近的k个(k<20),选择k个数据中出现次数最多的类别作为输入数据的分类

过程代码实现

  • 收集数据:可以使用任何方法。
  • 准备数据:距离计算所需要的数值,最好是结构化的数据格式。
  • 分析数据:可以使用任何方法。
  • 训练算法:此步骤不适用于k近邻算法。
  • 测试算法:计算错误率。
  • 使用算法:首先需要输入样本数据和结构化的输出结果,然后运行k近邻算法判定输入数据分别属于哪个分类,最后应用对计算出的分类执行后续的处理。

数据和源码

  • 最简单的例子
import numpy as np 
import operator

def createDataSet():
    ground = np.array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return ground, labels

# 分类算法
def classify0(inX,dataSet,labels,k):
    # 数据长度
    dataSetSize = dataSet.shape[0]
    # 计算inX点与其他所有点的距离,tile方法把inX点修改为矩阵
    diffMat = np.tile(inX, (dataSetSize,1)) - dataSet
    # 计算平方
    sqDiffMat = diffMat ** 2
    # 求和
    sqDistances = sqDiffMat.sum(axis = 1)
    # 求根号
    distance = sqDistances **0.5
    # 排序,并取其index存值
    sortedDistIndicies = distance.argsort()
    classCount = {}
    for i in range(k):
        voteLabel = labels[sortedDistIndicies[i]]
        # 有则+1,无则生成一个
        classCount[voteLabel] = classCount.get(voteLabel,0) + 1
    # 字典的排序
    sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    # python2下使用classCount.iteritems()代替classCount.items()
    return sortedClassCount[0][0]


if __name__=='__main__':
    group, labels = createDataSet()
    print(classify0([3,3] ,group, labels, 3))
  • 改进约会网站配对效果
import operator
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

def createDataSet():
    ground = np.array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return ground, labels

# 分类算法
def classify0(inX, dataSet, labels, k):  

    # 数据长度
    dataSetSize = dataSet.shape[0]
    # 计算inX点与其他所有点的激励,tile方法把inX点修改为矩阵
    diffMat = np.tile(inX, (dataSetSize,1)) - dataSet
    # 计算平方
    sqDiffMat = diffMat ** 2
    # 求和
    sqDistances = sqDiffMat.sum(axis = 1)
    # 求根号
    distance = sqDistances **0.5
    # 排序,并取其index存值
    sortedDistIndicies = distance.argsort()
    classCount = {}
    for i in range(k):
        voteLabel = labels[sortedDistIndicies[i]]
        # 有则+1,无则生成一个
        classCount[voteLabel] = classCount.get(voteLabel,0) + 1
    # 字典的排序
    sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]


# 文本转换为数据
def file2matrix(filename):
    fr = open(filename)
    arrayOlines = fr.readlines()
    numberOfLines = len(arrayOlines)
    # 初始化0矩阵
    returnMat = np.zeros((numberOfLines,3))
    classLabelVector = []
    index = 0 
    for line in arrayOlines:
        # strip() 方法用于移除字符串头尾指定的字符(默认为空格)。
        line = line.strip()
        listFromLine = line.split('\t')
        returnMat[index,:] = listFromLine[0:3]
        # 取倒数第一个数
        classLabelVector.append(int(listFromLine[-1]))
        index += 1
    return returnMat, classLabelVector

# 归一化操作
def autoNorm(dataSet):
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals-minVals
    normDataSet = np.zeros(np.shape(dataSet))
    m = dataSet.shape[0]
    normDataSet = dataSet-np.tile(minVals,(m,1))
    normDataSet = normDataSet/np.tile(ranges,(m,1))
    return normDataSet,ranges,minVals

# 测试函数
def datingClassTest():
    # 取前面10%的数据用作测试数据
    hoRatio =0.10
    datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m*hoRatio)
    errorCount = 0.0
    for i in range(numTestVecs):
        classifirerResult = classify0(normMat[i,:], normMat[numTestVecs:m,:], datingLabels[numTestVecs:m],3)
        print('判断类别是%d,正确答案是%d' % (classifirerResult, datingLabels[i]))
        if(classifirerResult != datingLabels[i]):
            errorCount +=1.0
        pass
    pass
    print('错误的个数是%d,错误率是%f' % (errorCount,errorCount/numTestVecs))
if __name__=='__main__':
    # group, labels = createDataSet()
    # classify0([3,3] ,group, labels, 3)
    # datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')
    # datingDataMat,ranges,minVals = autoNorm(datingDataMat)
    # fig = plt.figure()
    # ax = fig.add_subplot(111)
    # ax.scatter(datingDataMat[:,0],datingDataMat[:,1],5.0*np.array(datingLabels),15.0*np.array(datingLabels))
    # plt.show()
    datingClassTest()

  • 数据
    链接: https://pan.baidu.com/s/1pLVzIcF 密码: ay45
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值