机器学习实战_K近邻算法 —— 电影分类

一、数据参考

在这里插入图片描述

二、代码

import numpy as np
import operator


def createDataSet():
    """
    函数说明:创建数据集

    Parameters:
        None

    Returns:
        group - 数据集
        labels - 分类标签

    """
    # 七组二维特征
    group = np.array([[3, 104],
                      [2, 100],
                      [1, 81],
                      [101, 10],
                      [99, 5],
                      [98, 2],
                      [18, 90]])
    # 七组特征的标签
    labels = ['爱情片', '爱情片', '爱情片', '动作片', '动作片', '动作片', "未知"]
    return group, labels


def classify0(inX, dataSet, labels, k):
    """
    函数说明:kNN算法,分类器

    Parameters:
        inX - 用于分类的数据(测试集)(1*m向量)
        dataSet - 用于训练的数据(训练集)(n*m向量array)
        labels - 分类标准(n*1向量array)
        k - kNN算法参数,选择距离最小的k个点

    Returns:
        sortedClassCount[0][0] - 分类结果

    """
    # numpy函数shape[0]获取dataSet的行数
    dataSetSize = dataSet.shape[0]
    # 将inX重复dataSetSize次并排成一列,即将inX赋值dataSetSize行、1列
    diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet   # tile:复制函数
    # 矩阵数乘:矩阵对应位置元素相乘(array()函数中矩阵的乘积可以使用np.matmul或者.dot()函数。而星号乘 (*)则表示矩阵对应位置元素相乘,与numpy.multiply()函数结果相同)
    sqDiffMat = diffMat ** 2  # 每个元素 ** 2
    # sum()所有元素相加,sum(0)列相加,sum(1)行相加
    sqDistances = sqDiffMat.sum(axis=1)
    # 开方,计算出距离
    distances = sqDistances ** 0.5  # 每个元素 ** 0.5
    # argsort函数返回的是distances值从小到大排序后的索引值
    sortedDistIndicies = distances.argsort()
    # 定义一个记录类别次数的字典
    classCount = {}
    # 选择距离最小的k个点
    for i in range(k):
        # 取出前k个元素的类别
        voteIlabel = labels[sortedDistIndicies[i]]
        # 字典的get()方法,返回指定键的值,如果值不在字典中返回0
        # 计算类别次数
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # python3中用items()替换python2中的iteritems()
    # key = operator.itemgetter(1)根据字典的值进行排序
    # key = operator.itemgetter(0)根据字典的键进行排序
    # reverse降序排序字典
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    print("sortedClassCount:", sortedClassCount)
    # 返回次数最多的类别,即所要分类的类别
    return sortedClassCount[0][0]


if __name__ == '__main__':
    group, labels = createDataSet()

    result = classify0([70, 5], group, labels, 3)
    print(result)

    result = classify0([9, 79], group, labels, 3)
    print(result)

三、运行结果

在这里插入图片描述

  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值