Machine Learning in Action - k-Nearest Neighbors

Machine Learning in Action - 2.1 Overview of the k-Nearest Neighbors Algorithm

The k-nearest neighbors (kNN) algorithm works as follows. We have a sample data set containing many examples, and every example carries a class label: for instance, the point (1, 1) might carry label A and the point (-1, -1) label B. The two coordinates of (1, 1) are its features, while A is its label; the number of possible labels is finite (or, in probability-theory terms, at most countable), depending on the problem. Given a new, unlabeled data point, we compare its features with the corresponding features of every sample in the set and assign it the class of the most similar sample (its nearest neighbor). The algorithm is called k-nearest neighbors because in practice we take the k most similar samples and give the new point the label that occurs most often among those k; voting over k neighbors rather than relying on a single one reduces the misclassification rate and improves accuracy.
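
To make the voting idea concrete before the NumPy version below, here is a minimal pure-Python sketch of the same procedure; the function name knn_sketch and the toy samples are chosen purely for illustration:

from collections import Counter
from math import sqrt


def knn_sketch(query, samples, k):
    """samples is a list of ((x, y), label) pairs; query is an (x, y) point"""
    # Euclidean distance from the query to every labeled sample
    distances = [(sqrt((x - query[0]) ** 2 + (y - query[1]) ** 2), label)
                 for (x, y), label in samples]
    # keep the k closest samples and vote on their labels
    k_nearest = sorted(distances)[:k]
    votes = Counter(label for _, label in k_nearest)
    return votes.most_common(1)[0][0]


samples = [((1, 1), 'A'), ((1.1, 1), 'A'), ((-1, -1), 'B'), ((-0.9, -1), 'B')]
print(knn_sketch((0.8, 0.9), samples, 3))   # prints A: its three nearest neighbors are A, A, B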

All of the functions we use live in a file called KNN.py. We start with the problem just described: classifying points in a two-dimensional plane by their labels (the code targets Python 3.6):

from numpy import *
import operator


def createDataSet():
    group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels


def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    sqDiffMat = diffMat**2
    sqDistance = sqDiffMat.sum(axis=1)
    distance = sqDistance**0.5
    sortedDistIndicies = distance.argsort()
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

Below is my own line-by-line commentary on the code; corrections are welcome if anything is off.

from numpy import *
import operator

"""生成数据集及对应的标签"""
def createDataSet():
    group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels


"""输入向量inX, 数据集dataSet, 标签labels, 取前k个数"""
"""假设我们计算特征值之间的相似度使用的是距离,欧氏距离 d = sqrt((Ax-Bx)*(Ax-Bx) + (Ay-By)*(Ay-By))"""
def classify0(inX, dataSet, labels, k):

    """行数, 计算行使用shape[0]、 计算列使用shape[1]"""
    dataSetSize = dataSet.shape[0]

    """
        函数格式tile(A,reps), 理解为A重复reps次, 如tile((1, 3), (3, 2))
        得到的答案是[[1 3 1 3]
                    [1 3 1 3]
                    [1 3 1 3]]
        形如(A,(x, y))即为A重复x行,每行重复y次
        这里的tile(inX, (dataSetSize, 1))是把输入的数据生成dataSetSiz行每行1个
        假设输入的是inX是[0, 0],dataSet[[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]]是那么diffMat = tile(inX, (dataSetSize, 1)) - dataSet
        得到的结果就是[[0-1.0, 0-1.1]
                      [0-1.0, 0-1.0]
                      [0-0, 0-0]
                      [0-0, 0-0.1]]
        得到的diffMat是一个矩阵,数据与数据集相减之后的矩阵,为计算距离作准备
    """
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet

    """平方"""
    sqDiffMat = diffMat**2

    """求和,平时用的sum应该是默认的axis=0 就是普通的相加,而当加入axis=1以后就是将一个矩阵的每一行向量相加,这里是计算每一行的欧氏距离的平方"""
    sqDistance = sqDiffMat.sum(axis=1)

    """将得到的结果开根号就是就是距离了"""
    distance = sqDistance**0.5

    """
        argsort根据大小排序得到的结果是索引值
        如(0.2, 0.3, 0.1)
        得到的结果是(2, 0, 1),之后我们可以利用这个索引值找到相应的标签
    """
    sortedDistIndicies = distance.argsort()

    classCount = {}

    """选取前k个距离最小的,通过索引值得到标签"""
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

    """最后返回得到相应次数最多的标签, sorted函数的参数列表代表的含义可自行查找, python2第一个参数是classCount.iteritems()"""
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
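
The comments above lean on tile and argsort, so here is a short standalone demo of what they return (assuming NumPy is installed and the snippet is run in a separate scratch file):

from numpy import array, tile

# tile repeats the query vector so it can be subtracted from every row of the data set
print(tile([0, 0], (4, 1)))
# [[0 0]
#  [0 0]
#  [0 0]
#  [0 0]]

# argsort returns the indices that would sort the array, smallest value first
distance = array([0.2, 0.3, 0.1])
print(distance.argsort())
# [2 0 1]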

We create a new file, dataTest.py, to test the classifier:

import KNN

group, labels = KNN.createDataSet()

print(KNN.classify0([1.2, 0.8], group, labels, 3))
"""输出A"""
print(KNN.classify0([0.2, -0.1], group, labels, 3))
"""输出B"""

This is the simple example from the book; its main purpose is to illustrate the idea behind kNN.

Below is the book's code for "improving matches on a dating site with the k-nearest neighbors algorithm":

"""These functions also go in KNN.py; besides numpy and operator they need the following imports"""
import os
import matplotlib.pyplot as plt


def file2matrix(filename):
    fr = open(filename)
    arrayOLines = fr.readlines()
    numberOfLines = len(arrayOLines)
    returnMat = zeros((numberOfLines, 3))
    classLabelVector = []
    index = 0
    for line in arrayOLines:
        line = line.strip()
        listFromLine = line.split('\t')
        returnMat[index, :] = listFromLine[0:3]
        classLabelVector.append(listFromLine[-1])
        index += 1
    return returnMat, classLabelVector


def createScatterDiagram(datingDataMat, datingLabels):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter(datingDataMat[:, 0], datingDataMat[:, 1])
    plt.show()


def autoNorm(dataSet):
    minValue = dataSet.min(0)
    maxValue = dataSet.max(0)
    ranges = maxValue - minValue
    normDataSet = zeros(shape(dataSet))
    m = dataSet.shape[0]
    normDataSet = dataSet - tile(minValue, (m, 1))
    normDataSet = normDataSet / tile(ranges, (m, 1))
    return normDataSet, ranges, minValue


def datingClassTest():
    hoRatio = 0.10
    datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
    normMat, ranges, minValue = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m * hoRatio)
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify1(normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
        print("the classifier came vack whit: %s, the real answer is: %s" % (classifierResult, datingLabels[i]))
        if (classifierResult != datingLabels[i]):
            errorCount += 1.0
    print("Error rate is %f" % (errorCount / numTestVecs))


def classify1(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    sqDiffMat = diffMat ** 2
    sqDistance = sqDiffMat.sum(axis=1)
    distance = sqDistance ** 0.5
    sortedDistance = distance.argsort()
    classCount = {}
    for i in range(k):
        votaIlabel = labels[sortedDistance[i]]
        classCount[votaIlabel] = classCount.get(votaIlabel, 0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def classifyPerson():
    resultList = ['not at all', 'in small doses', 'in large doses']
    ffMiles = float(input("frequent flier miles earned per year?"))
    percentTats = float(input("percentage of time spent playing video games?"))
    iceCream = float(input("liters of ice cream consumed per year?"))
    inArr = array([ffMiles, percentTats, iceCream])
    datingDataMat, datingLabels = file2matrix("datingTestSet2.txt")
    normMat, ranges, minValue = autoNorm(datingDataMat)
    classifierResult = classify1((inArr - minValue) / ranges, normMat, datingLabels, 3)
    print("You will probably like this person: ", resultList[int(classifierResult) - 1])


def classifyPersonFromFile():
    resultList = ['not at all', 'in small doses', 'in large doses']
    datingDataMat, datingLabels = file2matrix("datingTestSet2.txt")
    normMat, ranges, minValue = autoNorm(datingDataMat)
    TestdatingDataMat = datingDataMat
    matSize = TestdatingDataMat.shape[0]
    output = open("a.txt", "w")
    for i in range(matSize):
        ffMiles = float(TestdatingDataMat[i][0])
        percentTats = float(TestdatingDataMat[i][1])
        iceCream = float(TestdatingDataMat[i][2])
        inArr = array([ffMiles, percentTats, iceCream])
        classifierResult = classify1((inArr - minValue) / ranges, normMat, datingLabels, 3)
        output.write("%-8s %-10s %-9s %s\n" % (str(int(ffMiles)), str(percentTats), str(iceCream), str(resultList[int(classifierResult) - 1])))
    output.close()
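
For reference, file2matrix expects every line of datingTestSet2.txt to hold three tab-separated numeric features followed by a class label (1, 2, or 3 in the book's data set). The functions above can then be driven by a few lines like these (a usage sketch; the sample rows in the comment are illustrative, not copied from the real file):

import KNN

# Each line of datingTestSet2.txt: miles<TAB>game-time%<TAB>ice-cream<TAB>label, e.g.
# 40920    8.32    0.95    3
# 14488    7.15    1.67    2

KNN.datingClassTest()    # hold out 10% of the rows and print the error rate
KNN.classifyPerson()     # classify a new person from three typed-in feature values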

Below is the book's code for the "handwriting recognition system":

def img2vector(filename):
    returnVector = zeros((1, 1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVector[0, i*32 + j] = int(lineStr[j])
    return returnVector


def handWritingClassTest():
    hwLabels = []
    trainingFileList = os.listdir('trainingDigits')
    length = len(trainingFileList)
    trainingMat = zeros((length, 1024))
    for i in range(length):
        fileName = trainingFileList[i]
        fileLabel = fileName.split("_")[0]
        hwLabels.append(fileLabel)
        trainingMat[i, :] = img2vector('trainingDigits/%s' % fileName)
    testFileList = os.listdir('testDigits')
    testLength = len(testFileList)
    errorCount = 0.0
    for i in range(testLength):
        testFileName = testFileList[i]
        testVector = img2vector('testDigits/%s' % testFileName)
        label = testFileName.split("_")[0]
        classifierResult = classify1(testVector, trainingMat, hwLabels, 5)
        if classifierResult != label:
            errorCount += 1.0
        print("the classifier came back with: %s, the real answer is: %s" % (classifierResult, label))
    print("the total error number is: %s" % (errorCount))
    print("the rate is: %s" % (errorCount/float(testLength)))