[机器学习实战] k-近邻算法

原理

存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系。输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,我们只选择样本数据集中前k各最相似的数据,这就是k-近邻算法中k的出处,通常k是不大于20的整数。最后,选择k个最相似数据中出现次数最多的分类,作为新数据的分类。

k-近邻算法的优缺点

有点:精度高、对异常值不敏感、无数据输入假定
缺点:计算复杂度高、空间复杂度高
使用数据范围:数值型和标称型
通常k是不大于20的整数

k-近邻算法的一般流程

(1)收集数据:可以使用任何方法
(2)准备数据:距离计算所需要的数值,最好是结构化的数据格式
(3)分析数据:可以使用任何方法
(4)训练算法:此步骤不适用于k-近邻算法
(5)测试算法:计算错误率
(6)使用算法:首选需要输入样本数据和结构化的输出结果,然后运行k-近邻算法判定输入数据分别属于哪个分类,最后应用对计算出的分类执行后续的处理

k值的选择

k值越小,整体模型变得越复杂,预测结果对近邻的实例点敏感,容易发生过拟合。k值越大,模型变得简单,可以减小学习的估计误差,但学习的近似误差会增大。在应用中,k值一般取一个比较小的数值,通常采用交叉验证法来选取最优的k值。

常用函数

(1)对arr重复x行y列构成新的arr
      tile(arr, (x, y))
(2)对arr重复x列构成新的arr
     tile(arr, x)
(3)对矩阵纵向上求和
     mat.sum(axis=0)
(4)对矩阵横向求和
    mat.sum(axis=1)  
(5)对dict排序,选择第1列作为key(下标从0开始)
    sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
(6)对array进行排序,返回排序后的下标数组
    array.argsort()
(7)重新加载模块,模块有更新的情况下
    reload(module)
(8)对每一列,取最小值,形成新的array
    array.min(0)
(9)显示path目录下的所有文件
    listdir(path)
注意:NumPy库提供的数组操作并不支持Python自带的数组类型,因此在编写代码时要注意不要使用错误的数组类型

样例代码

DataUtil.py
1. 用于随机生成数据集
2. 用于随机生成测试向量
3. 用于归一化
4. 用于按照比例随机切分训练集和测试集
# -*- coding: utf-8 -*-

from numpy import *

class DataUtil:
    def __init__(self):
        pass

    def randomDataSet(self, row, column, classes):
        '''rand data set'''
        if row <= 0 or column <= 0 or classes <= 0:
            return None, None
        dataSet = random.rand(row, column)
        dataLabel = [random.randint(classes) for i in range(row)]
        return dataSet, dataLabel

    def file2DataSet(self, filePath):
        '''read data set from file'''
        f = open(filePath)
        lines = f.readlines()
        dataSet = None
        dataLabel = []
        i = 0
        for line in lines:
            items = line.strip().split('\t')
            if dataSet is None:
                dataSet = zeros((len(lines), len(items)-1))
            dataSet[i,:] = items[0:-1]
            dataLabel.append(items[-1])
            i += 1
        return dataSet, dataLabel

    def randomX(self, column):
        '''rand a vector'''
        return random.rand(1, column)[0]

    def norm(self, dataSet):
        '''normalize'''
        minVals = dataSet.min(0)
        maxVals = dataSet.max(0)
        ranges = maxVals - minVals
        m = dataSet.shape[0]
        return (dataSet - tile(minVals, (m, 1)))/tile(ranges, (m, 1))

    def spitData(self, dataSet, dataLabel, ratio):
        '''split data with ratio'''
        totalSize = dataSet.shape[0]
        trainingSize = int(ratio*totalSize)
        testingSize = totalSize - trainingSize

        # random data
        trainingSet = zeros((trainingSize, dataSet.shape[1]))
        trainingLabel = []
        testingSet = zeros((testingSize, dataSet.shape[1]))
        testingLabel = []
        trainingIndex = 0
        testingIndex = 0
        for i in range(totalSize):
            r = random.randint(1, totalSize)
            if (r <= trainingSize and trainingIndex < trainingSize) or testingIndex >= testingSize:
                trainingSet[trainingIndex,:] = dataSet[i,:]
                trainingLabel.append(dataLabel[i])
                trainingIndex += 1
            else:
                testingSet[testingIndex,:] = dataSet[i,:]
                testingLabel.append(dataLabel[i])
                testingIndex += 1
        return trainingSet, trainingLabel, testingSet, testingLabel

kNN.py
1. k-近邻算法的实现
# -*- coding: utf-8 -*-

import operator
from numpy import *

class kNN:
    def __init__(self):
        pass

    def classify(self, dataSet, dataLabel, vectorX, k):
        # data validate
        (row, column) = dataSet.shape
        if row <= 0 or column <= 0 or row != len(dataLabel) or column != len(vectorX) or k <= 0:
            return None, None

        # calculate distance and sort
        dataX = tile(vectorX, (row, 1))
        distance = (((dataX - dataSet)**2).sum(axis=1))**0.5
        sortedIndice = distance.argsort()

        # classify
        classCount = {}
        for i in range(k):
            if i >= row:
                break
            label = dataLabel[sortedIndice[i]]
            classCount[label] = classCount.get(label, 0) + 1

        # sort and return result
        return distance, sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)[0][0]

Test4knn.py
1. 用于测试k-近邻算法
# -*- coding: utf-8 -*-

from com.fighting.util.DataUtil import *
from com.fighting.knn.kNN import *
import matplotlib.pyplot as plt

def knn():
    '''test knn'''
    row, column, classes, k = (100, 5, 3, 10)

    # load data set
    dataUtil = DataUtil()
    dataSet, dataLabel = dataUtil.randomDataSet(row, column, classes)
    print 'dataSet: '
    print dataSet
    print 'dataLabel: '
    print dataLabel

    # normalize
    dataSet = dataUtil.norm(dataSet)
    print 'norm-dataSet:'
    print dataSet

    # plot the data
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter(dataSet[:,0], dataSet[:,1], 15*array(dataLabel), 15*array(dataLabel))
    plt.show()

    # random vector X
    vectorX = dataUtil.randomX(dataSet.shape[1])
    print 'vectorX: '
    print vectorX

    # classify
    knn = kNN()
    distance, clz = knn.classify(dataSet, dataLabel, vectorX, k)
    print 'distance: '
    print distance
    print 'clz=%d' % clz

def dating():
    '''test dating classify'''
    # load data set
    dataUtil = DataUtil()
    dataSet, dataLabel = dataUtil.file2DataSet('../../../datasets/knn/datingTestSet.txt')
    dataSet = dataUtil.norm(dataSet)

    # split training set and testing set
    ratio = 0.8
    trainingSet, trainingLabel, testingSet, testingLabel = dataUtil.spitData(dataSet, dataLabel, ratio)
    testingSize = testingSet.shape[0]

    # training and testing
    knn = kNN()
    for k in range(1, 11):
        error = 0
        for i in range(testingSize):
            distance, clz = knn.classify(trainingSet, trainingLabel, testingSet[i,], k)
            if clz != testingLabel[i]:
                error += 1
        print '%d, %.2f' % (k, error*1.0/testingSize)

def f2d():
    '''test file2dataset'''
    dataUtil = DataUtil()
    dataSet, dataLabel = dataUtil.file2DataSet('../../../datasets/knn/datingTestSet.txt')
    print 'dataSet:'
    print dataSet
    print 'dataLabel:'
    print dataLabel

if __name__ == '__main__':
    knn()
    #dating()
    #f2d()



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值