CS231n 02- KNN

本文介绍了图像分类任务中的KNN算法,包括K近邻分类器的工作原理、距离计算方法(L1和L2距离)、超参数选择的影响以及KNN的缺点。还探讨了线性分类器的适用场景和限制,强调了预测速度与训练复杂度的权衡。
摘要由CSDN通过智能技术生成

Image classification task

机器学习:基于数据的解决方案

步骤:1.收集大量的图片并打上标签 2.使用机器学习算法来训练一个分类器 3.用新的图片来评估分类器的好坏

Nearest Neighbor Classifier 近邻分类器

https://cdn.jsdelivr.net/gh/QingYuAnWayne/PicStorage//20210307154602.png

记忆数据和标签并训练出模型,再把测试集图片预测为最有可能的标签

通过距离测量来比较图片

L1 距离 ,相当于得到的是差异性的值

https://cdn.jsdelivr.net/gh/QingYuAnWayne/PicStorage//20210307154902.png

算法:

https://cdn.jsdelivr.net/gh/QingYuAnWayne/PicStorage//20210307155059.png

在训练阶段,相当于只是记忆了所有图片的像素矩阵。

在预测阶段,对每一张测试图片,找到距离最小的训练集中的图片,把该图片的标签作为预测的标签。

然而,非常明显的是,假设我们有N个测试样本,训练的时间复杂度是O(1),而预测的时间复杂度是O(N)

对于实际使用中,我们更希望是训练的时间很长,而预测的时间很短,因此这种算法很辣鸡

KNN K近邻算法

相比于之前的从距离最近的数据找标签,KNN是找到K个最近的点的标签,再进行投票来预测这个点的标签。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3dZ6TSzF-1618536459940)(https://cdn.jsdelivr.net/gh/QingYuAnWayne/PicStorage//20210307155805.png)]

KNN的距离计算量度

  1. L1距离(曼哈顿距离)

    计算出样本每个维度的坐标之差的绝对值相加,计算出差异。

    https://cdn.jsdelivr.net/gh/QingYuAnWayne/PicStorage//20210307155930.png

  2. L2距离(欧几里得距离)

    https://cdn.jsdelivr.net/gh/QingYuAnWayne/PicStorage//20210307160236.png

如果点的每个维度的数值具有明确意义(例如企业中员工的数据:薪资、年龄、职位等等),那么L1距离更合适。但是,缺点是如果坐标轴旋转的话,那么结果会随之改变。

而L2距离更合适在点的每个维度的数值并没有什么实际的意义的情况下。如果上述的例子用L2距离,那么这些数据就没有意义了。

Hyperparameters 超参数

k值和距离量度的选择对模型的准确度有很大影响。超参数是只和算法本身有关的参数。

选择超参数的方法很大程度上依赖于模型要预测的事件或物体。需要全部尝试,来找出最合适的方案。

方法1:

在所有数据集上测试并找到最好的。 但显然K=1永远是最好的,pass

https://cdn.jsdelivr.net/gh/QingYuAnWayne/PicStorage//20210307161222.png

方法2:

将数据集分为训练数据和测试数据,通过训练数据和测试数据的表现来选择。但是最后得到的结果只是针对于测试数据的,不清楚算法对于未知的数据是怎么样的结果,鲁棒性太差,pass

https://cdn.jsdelivr.net/gh/QingYuAnWayne/PicStorage//20210307161444.png

方法3:(比较好的方案)

将数据集分为训练、验证、测试数据。在测试集上找到表现最好的参数,并通过测试数据来评估模型准确率。

https://cdn.jsdelivr.net/gh/QingYuAnWayne/PicStorage//20210307161556.png

方法4:

交叉验证。将数据集分组,每组(除了测试集)都会当作一个验证集,然后把每组的结果平均,找到最优参数。但是,太消耗算力和时间,通常只在小规模数据中使用。

https://cdn.jsdelivr.net/gh/QingYuAnWayne/PicStorage//20210307161641.png

KNN的缺点

  1. 不会用在像素距离(图像识别)上。因为对像素值的距离量度并不能得到有用的信息。并且预测时间太长了。

    在以下图片中,每个的L2距离都和原始图片一样,但其实看起来是有差异的,机器却会认为他们都是一样的。

    https://cdn.jsdelivr.net/gh/QingYuAnWayne/PicStorage//20210307162159.png

  2. 维度灾难:维度的上升会让数据集的数量指数倍的扩大。 因为KNN算法实质上是将样本空间分成几块区域,根据新的数据所在的区域确定该数据的标签。因此希望训练数据越紧密越好,如果过于分散,那么就会导致算法不能很好的识别。所以当维度越多时,就会导致需要的训练集数据的数量越多,越紧密,可能没有足够的数据集来满足这些条件。

https://cdn.jsdelivr.net/gh/QingYuAnWayne/PicStorage//20210307163305.png

Linear Classifier 线性分类器

https://cdn.jsdelivr.net/gh/QingYuAnWayne/PicStorage//20210309160025.png

原图像是一个32323 的数组,其中3是代表RGB三个通道。矩阵乘后最终会得到10个不同标签的分数,某一类相对于其他类的分数越高,代表该类的可能性越大。

f(x, W) = Wx + b

其中,f是10x1的矩阵,W是10x3072的矩阵,x是3072x1的矩阵。b是bias。如果没有bias,那么当x取0时,每条分类线都会经过原点。

https://cdn.jsdelivr.net/gh/QingYuAnWayne/PicStorage//20210309161006.png

线性分类器更像是用一条线(二维平面上)把图片分割成不一样的区域,从而由新的图片所在的区域来预测该图片的标签是什么

线性分类器有一些无法使用的地方,由于必须要找到一条直线来进行分类,那么就无法处理以下问题

https://cdn.jsdelivr.net/gh/QingYuAnWayne/PicStorage//20210309162144.png

其次,线性分类器的W每一行都代表一个类别的标签,因此只能根据一个模板来训练(我觉得找不到两个模板来确定一条直线,所以不行),因此训练过程还是简单的。由于线性分类器的工作原理其实和KNN很相似,都是找到点和模板的相似度来得出分数(在线性分类器中,如果模板是红色的车停在绿色的草地上,那么分类器会更偏向于选择红色在中间,绿色围绕的图片),只不过测试速度变快了(因为不用一个一个比较,而是通过权重矩阵W计算出分数)。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值