机器学习算法学习笔记(1) KNN算法

##以下内容使用的数据集并非真实数据,只是为了方便理解kNN算法

kNN算法的思想很简单,我的理解如下:在已有数据集的情况下,对于一个新的需要预测的数据点,只需要去找到距离新点最近的k个已有数据点,根据这k个已有数据点的label投票决定新点的label。

本文采用的数据集为随机生成的:

# -*- coding: UTF-8 -*-
import numpy as np

datax1 = np.random.normal(20,3,5).reshape(5,1)
datax2 = np.random.normal(30,4,5).reshape(5,1)

datax = np.vstack([datax1,datax2])

datay1 = np.random.normal(50,4,5).reshape(5,1)
datay2 = np.random.normal(40,4,5).reshape(5,1)

datay = np.vstack([datay1,datay2])

data = np.hstack([datax,datay])
print(data)

由以上生成程序我得到了:

data=np.array([[23.94274904,52.39849734],
			 [25.21766745,55.00034639],
			 [22.76012006,53.54882548],
			 [16.75486411,49.73771554],
			 [22.24265087,49.51301099],
			 [31.6994235,35.19098192],
			 [28.21982591,43.57930998],
			 [27.72494664,29.22964265],
			 [28.20761154,40.50009104],
			 [27.35689442,35.94715885]])

而它们对应的label数组为:

label = np.array([0,0,0,0,0,1,1,1,1,1])

kNN算法

from math import sqrt
#现在有一个需要预测的点a
a = np.array([32.26801694,54.14229264])

#计算每个已有数据点到预测点a的欧拉距离
distance = []
for row in data:
	dis = sqrt(np.sum((row-a)**2))
	distance.append(dis)

#以上可以用一行代码解决
#distance=[sqrt(np.sum((row-a)**2)) for row in data]

#若k取6
k = 6
#得到按distance排序后的下标,并取前k位
argdis = np.argsort(distance)[:k]
label_list = label[argdis].tolist()
print(label_list)

执行上述代码,得到[0, 0, 0, 0, 0, 1],此为前6个距离最近的数据点的label。

最后统计得到预测结果

#转化成set去重
lset = set(label_list)
#生成统计字典
times = {item:label_list.count(item) for item in lset}
#获得字典中value最大的key值
predict = max(times,key=times.get)
print("预测值为:%s" % predict)

除了这种方法,这里也可以使用collections中的Counter类进行统计,Counter类有一个most_common方法,可以直接得出最后结论。如下

from collections import Counter
votes = Counter(label_list)     #注意这里是List,不是Array
predict = votes.most_common(1)[0][0]   #找到个数最多的所对应的label
print("预测值为:%s" % predict)

kNN算法是不需要训练过程的机器学习算法,这是它的特殊性。为了和其它算法统一,可以认为它的数据训练集就是模型本身。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值