chapter2--KNN算法

KNN算法是懒惰的学习算法,没有明显的训练过程,预测时只需要使用已经有标注(分类学习)的训练数据即可

适用于多分类的学习任务

from numpy import*
import operator
import pdb
import matplotlib
import matplotlib.pyplot as plt
from os import listdir



#测试数据
def createDataSet():
    group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
    labels = ['A','A','B','B']
    return group,labels


def classify0(inX, dataSet, labels, k):
    '''
    K-邻近算法:
    inX:目标点
    dataSet:数据集
    labels:数据的标签,label的列数和dataSet一样
    K:选取的K个邻近值

    返回inx的属性
    '''
    dataSetSize = dataSet.shape[0]  # 数据集大小

    # 目标点到k邻近点的距离
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    sqDiffMat = diffMat ** 2  # 平方
    sqDistances = sqDiffMat.sum(axis=1)  # 求和
    distances = sqDistances ** 0.5  # 开方

    sortedDistIndicies = distances.argsort()  # 距离排序,返回的是distance排序后的索引

    # 统计K个点所属类别
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    # 返回频率最高的标签类别
    return sortedClassCount[0][0]


#将读取文件,将文件转化为numpy数据
def file2matrix(filename):
    fr = open(filename)
    arrayOLines = fr.readlines()
    numberOfLines = len(arrayOLines) #文件的行数
    returnMat = zeros((numberOfLines,3)) #初始化矩阵
#     pdb.set_trace()
    classLabelVector = []
    index = 0
    for line in arrayOLines:
        line = line.strip()
        listFromLine = line.split('\t')
        returnMat[index,:] = listFromLine[0:3]
        classLabelVector.append(int(listFromLine[-1]))
        index += 1
    return returnMat,classLabelVector


#归一化数据
def autoNorm(dataSet):
#     pdb.set_trace()
    minVals = dataSet.min(0) #行最大,若参数为1 则是列最大
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    normDataSet = zeros(shape(dataSet))
    m = dataSet.shape[0] #dataSet的行数,若参数为1则为列数
    normDataSet = dataSet - tile(minVals,(m,1)) #计算每个数据与最小数据之间的差值
    normDataSet = normDataSet/tile(ranges,(m,1))#归一化
    return normDataSet,ranges,minVals


#绘图
def data_plt(datingDataMat,datingLabels):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2], 15.0 * array(datingLabels), 15.0 * array(datingLabels))
    plt.show()


#分类器针对约会网站的测试代码
def datingClassText(datingDataMat,datingLabels):
    hoRation = 0.10 #10%测试集
    m = datingDataMat.shape[0]
    numTestVecs = int(m*hoRation) #测试集总量
    errorCount = 0.0
    for i in range(numTestVecs):
        '''
        classify0参数
        (预测点,数据集,标注,K的取值)
        '''
        classifierResult = classify0(datingDataMat[i,:],datingDataMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
        print(classifierResult,datingLabels[i])
        if(classifierResult != datingLabels[i]):
            errorCount += 1.0 #错误分类计数
    print(errorCount/float(numTestVecs)) #错误分类概率


#约会网站预测函数
def classifyPerson(datingDataMat,datingLabels,ranges,minVals):
    resultList = ['not at all','in small doses','in large doses']
    #输入待测试数据,并将测试数据转换成numpy数据
    percentTats = float(input("每年玩游戏的时间"))
    ffMiles = float(input("每年飞行公里数"))
    iceCream = float(input("每年冰激凌消耗数"))
    inArr = array([ffMiles,percentTats,iceCream])
    #使用KNN算法
    classifierResult = classify0((inArr-minVals)/ranges,datingDataMat,datingLabels,3)
    #输出预测结果
    print('the result is :',resultList[classifierResult-1])


'''
手写识别系统
数据解释:trainingDigits是训练数据,testDigits是测试数据
'''
#每次只读一张图片(一个txt文件),将每张图片合成的数据,改成1*1024的numpy矩阵
def img2vector(filename):
    returnVect = zeros((1,1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0,32*i+j] = int(lineStr[j])
    return returnVect

#识别程序
def handwritingClassTest(traindir,testdir):

    '''
     读取训练数据,生成训练集
    '''
    hwLabels = [] #用于存储每张图片所表示的数字,即数据属性
    trainingFileList = listdir(traindir) #读取指定文件下下所有文件
    m = len(trainingFileList)#计算总文件数目
    trainingMat = zeros((m,1024))#存储所有图片数字矩阵,m*1024,
    for i in range(m):#"0_7.txt" 文件名样式
        fileNameStr = trainingFileList[i] #第i个文件的文件名
        fileStr = fileNameStr.split('.')[0] #取文件名的前面,即去掉txt
        classNumStr = int(fileStr.split('_')[0]) #从文件名获取当前文件所保存的图像表示的数字
        hwLabels.append(classNumStr)
        path = '{}/{}'.format(traindir, fileNameStr)
        trainingMat[i:] = img2vector(path) #循环读入每个文件

    '''
    读取测试数据,对每个数据进行测试
    '''
    testFileList = listdir(testdir)
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('{}/{}'.format(testdir,fileNameStr))
        classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3) #进行预测
        if classifierResult != classNumStr: #错误预测结果计数
            errorCount += 1
        print(" {},   {}".format(classifierResult,classNumStr))
    print('the error rate is:',errorCount/float(mTest)) #错误预测率



if __name__ == '__main__':
    # group,labels = createDataSet()
    # datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')
    # datingDataMat,ranges,minVals = autoNorm(datingDataMat)
    # # data_plt(datingDataMat,datingLabels)
    # datingClassText(datingDataMat,datingLabels)
    # classifyPerson(datingDataMat,datingLabels,ranges,minVals)
    # img2vector('digits/trainingDigits/0_1.txt')
    path_train = 'digits/trainingDigits'
    path_test = 'digits/testDigits'
    handwritingClassTest(path_train,path_test)
    # pdb.set_trace()

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
【1】项目代码完整且功能都验证ok,确保稳定可靠运行后才上传。欢迎下载使用!在使用过程中,如有问题或建议,请及时私信沟通,帮助解答。 【2】项目主要针对各个计算机相关专业,包括计科、信息安全、数据科学与大数据技术、人工智能、通信、物联网等领域的在校学生、专业教师或企业员工使用。 【3】项目具有较高的学习借鉴价值,不仅适用于小白学习入门进阶。也可作为毕设项目、课程设计、大作业、初期项目立项演示等。 【4】如果基础还行,或热爱钻研,可基于此项目进行二次开发,DIY其他不同功能,欢迎交流学习。 【注意】 项目下载解压后,项目名字和项目路径不要用中文,否则可能会出现解析不了的错误,建议解压重命名为英文名字后再运行!有问题私信沟通,祝顺利! 基于C语言实现智能决策的人机跳棋对战系统源码+报告+详细说明.zip基于C语言实现智能决策的人机跳棋对战系统源码+报告+详细说明.zip基于C语言实现智能决策的人机跳棋对战系统源码+报告+详细说明.zip基于C语言实现智能决策的人机跳棋对战系统源码+报告+详细说明.zip基于C语言实现智能决策的人机跳棋对战系统源码+报告+详细说明.zip基于C语言实现智能决策的人机跳棋对战系统源码+报告+详细说明.zip基于C语言实现智能决策的人机跳棋对战系统源码+报告+详细说明.zip基于C语言实现智能决策的人机跳棋对战系统源码+报告+详细说明.zip基于C语言实现智能决策的人机跳棋对战系统源码+报告+详细说明.zip基于C语言实现智能决策的人机跳棋对战系统源码+报告+详细说明.zip基于C语言实现智能决策的人机跳棋对战系统源码+报告+详细说明.zip基于C语言实现智能决策的人机跳棋对战系统源码+报告+详细说明.zip基于C语言实现智能决策的人机跳棋对战系统源码+报告+详细说明.zip
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值