1.KNN原理
K-Nearest Neighbors
很简单,看图就一目了然了。
![db094a46ab803c8d93a8713c25be52d9.png](https://i-blog.csdnimg.cn/blog_migrate/c4d97fc3f4e90e01856867812a34d404.jpeg)
绿色和蓝色是已知数据,根据已知数据,我们想要知道,红色的点属于哪一类。
这是,我们选择的方法是:看距离红色最近的一个点(K=1)是属于哪一类,或者看距离红色最近的两个个点(K=1)或三个点(K=3)是属于哪一类,此时,我们需要在这些点里做投票,看看这个区域内,哪个颜色的点数量多。
我们以K=3为例来看:
![f253fa273f94728edea2ef1070a3459f.png](https://i-blog.csdnimg.cn/blog_migrate/8daf5d1a61d3d8cf74921a281b52dfee.jpeg)
在选择距离红色点最近的三个点作为参考时,绿色占两个,蓝色占一个,这是,我们就将红色归为绿色一类。
这也告诉我了我们,K在通常情况下要选择奇数,否则很容易出现两种情况数量相等的情况(如K为4,可能会出现两蓝两绿的情况)
2.KNN算法
KNN具体的实现方法
1.特征工程,Feature Engineering,把物体用向量、矩阵、张量的形式表示出来
2.标记号:每个物体的标签
3.计算两个物体之间的距离,最常用的就是欧式距离:
4.选择合适的K
- 首先,要理解什么是决策边界(非常重要):
- 决策边界决定线性分类器和非线性分类器。
- 决策边界不可以过于陡峭(过拟合)
看一个小例子
![b01b61b9a258298b2a85b268dd8a8c85.png](https://i-blog.csdnimg.cn/blog_migrate/2ff0ee1815086e9e9df8edda4af033de.jpeg)
根据距离橙/绿点的距离,填充背景。那么很明显,在右图我们可以看到清晰的决策边界。
- 用交叉验证方法(Cross Validation)选择K:
![926988ff3258ef78e965ac4a8c49bd6f.png](https://i-blog.csdnimg.cn/blog_migrate/290651892e9ae619feeb0ceec4467269.jpeg)
将所有数据分为训练数据(绿色)和测试数据(橙色),默认为75%和25%,可以自行调整。
交叉验证是将训练数据再次分配,我们以5折为例,就是说将交叉数据分成五份,每次都选取不同的数据作为验证数据(蓝色)。
首先验证K=1
每一组不同的验证数据都会得出一个准确度,求得五组准确度的平均值,就是K=1情况下的准确度。
![460e41c5d708cce4e843b39eedc209c1.png](https://i-blog.csdnimg.cn/blog_migrate/a0d15d9485217baa5498428ff9c6bee0.jpeg)
同理,当K=3时,方法同上,求出平均准确度
![23e1f70caf8e846c02716ac712d5b247.png](https://i-blog.csdnimg.cn/blog_migrate/2882f456c30cc6fe1e9b1a4e801fec11.jpeg)
求出两个准确度之后,看哪一个准确度更高,说明哪一个K值取得好。当然在实际中,交叉验证的不会只有K=1K=3两个值,经常都是交叉验证很多个K值,取其中准确度最高的为我们需要的K值。
当数据量少的时候,可以适当增加折数。
切记:不能用测试数据(橙色)来调参(K)!!!!!!!!
5.其他需要注意的地方
特征的缩放:特征中量纲差距过大,导致其中一个或几个特征在计算中基本起不到作用,这是我们需要特征的缩放。方法如下:
- 线性归一化(Min-Max-Normalization):一般情况Min=0,Max=1
公式为:
- 标准差标准化(Z-Score-Normalization):
公式为:
3.KNN代码
详细请看注释
1. 调用KNN函数来实现分类
数据采用的是经典的iris数据,是三分类问题
# 读取相应的库
2.从零开始自己写一个KNN算法
from
3.KNN的决策边界
import
![0dff3dc4dd3fbb0f15a38844b7dd8728.png](https://i-blog.csdnimg.cn/blog_migrate/2eba4b7cc773efbaf77d329c270d3a2a.jpeg)
#K=1 不平缓,陡峭,不稳定
4.KNN回归
import
# 数据转换
![b7ad9e49685accca60f7317fda5a428f.png](https://i-blog.csdnimg.cn/blog_migrate/d48d7720b8bc916287d353fa7e2a45fc.jpeg)
#sns.pairplot(df[['Construction Year', 'Days Until MOT', 'Odometer', 'Ask Price']], size=2)
![227914a584350dc20bd8c48117318c2a.png](https://i-blog.csdnimg.cn/blog_migrate/c9312f500ccd767283d36d70f116d73a.png)
[1199. 1199. 700. 899.]
knn
KNeighborsRegressor(algorithm='auto', leaf_size=30, metric='minkowski',
metric_params=None, n_jobs=None, n_neighbors=2, p=2,
weights='uniform')
pred = knn.predict(X_test)
pred
array([ 1.36676513, 1.36676513, -0.68269804, 0.13462294])
from sklearn.metrics import mean_absolute_error
mean_absolute_error(y_pred_inv, y_test_inv)
175.5
from sklearn.metrics import mean_squared_error
mean_squared_error(y_pred_inv, y_test_inv)
56525.5
y_pred_inv
array([1199., 1199., 700., 899.])
y_test_inv
array([[1300.],
[1650.],
[ 650.],
[ 799.]])
如果有兴趣,可以关注我的公众号(圈圈小姐),编程小白不定期分享自己的学习成果(以及生活旅行和德语小知识点),大家一起进步~~