机器学习实战-k近邻算法

k-近邻算法概述

k-近邻算法采用测量不同特征之间的距离方法进行分类

k-近邻算法(knn):存在一个样本数量集合,也称训练样本集,并且样本几张每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系。输入每一标签的新数据后,将新数据的每个特征与样本几种数据对应的特征进行比较,然后算法提取样本集中特征最相似数据的分类标签。通常k是不大于20的整数

使用knn算法分类爱情片和动作片,思路:先统计6部电影的打斗镜头,接吻镜头数,属于哪种电影类型,然后分析新电影样本与已知样本的距离,选择距离最近的前三部电影,发现都是爱情片,由此判断此部电影是爱情片。

1.1 使用python 导入数据

1.2 实施knn分类算法

#伪代码
1)计算已知类别数据集中的点与当前点之间的距离
2)按照距离递增次序排序
3)选取与当前点距离最小的k个点
4)确定前k个点所在类别的出现频率
5)返回前k个点出现频率最高的类别作为当前点的预测分类
#k近邻算法  k表示选择最近邻居的数目
def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX, (dataSetSize,1)) - dataSet
    sqDiffMat = diffMat**2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances**0.5  //距离计算
    sortedDistIndicies = distances.argsort()     
    classCount={}          
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
    #分类,返回分类标签
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

2.1准备数据:从文本文件中解析数据

#读取文件数据
def file2matrix(filename):
    fr = open(filename)
    arrayL = fr.readlines()
    numL = len(arrayL)
    omatrix = zeros((numL, 3))
    classvector = []
    index = 0
    for line in arrayL:
        line = line.strip()
        listform = line.split("\t")
        omatrix[index, :] = listform[0:3]
        classvector.append(int(listform[-1]))
        index += 1
    return omatrix, classvector

2.2 分析数据:使用matplotlib 创建散点图

2.3 准备数据:归一化数据

在处理不同取值范围的特征值时,通常采用归一化数据,使得每个特征的权重相等,避免某个特征值严重影响计算结果

#归一化值=(原来的值-min) /max-min
def autoNorm(dataset):
    minval = dataset.min(0)
    maxval = dataset.max(0)
    range = maxval - minval
    normSet = zeros(shape(dataset))
    m = dataset.shape[0]
    normSet = dataset - tile(minval, (m, 1))
    normSet = normSet / tile(range, (m, 1))
    return normSet, range, minval

2.4测试算法,作为完整程序验证分类器

def datingClassTest():
    hoRatio = 0.50      #hold out 10%
    datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')       #load data setfrom file
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m*hoRatio)
    errorCount = 0.0
    for i in range(numTestVecs):
        #输出第i个的错误率
        classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
        if (classifierResult != datingLabels[i]): errorCount += 1.0
    print("the total error rate is: %f" % (errorCount/float(numTestVecs)))
    print(errorCount)

2.5 使用算法

3 手写识别系统

3.1 准备数据:将图像转化为测试向量

#32*32的图像格式化为1*1024的向量
def img2vector(filename):
    returnVect = zeros((1,1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0,32*i+j] = int(lineStr[j])
    return returnVect

3.2 测试算法:使用k-近邻算法识别手写数字

主要是classify0() 函数,把输入参数,训练数据,训练数据分类拿到,再调用函数

def handwritingClassTest():
    hwLabels = []
    trainingFileList = listdir('trainingDigits')           #load the training set
    m = len(trainingFileList)
    trainingMat = zeros((m,1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]     #take off .txt
        #9_45.txt 得出标签是9
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
    testFileList = listdir('testDigits')        #iterate through the test set
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]     #take off .txt
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        #主要函数,由于文件中值都在0,1之间,因此不需要使用autoNorm()
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
        if (classifierResult != classNumStr): errorCount += 1.0
    print("\nthe total number of errors is: %d" % errorCount)
    print("\nthe total error rate is: %f" % (errorCount/float(mTest)))

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值
>