Linear Classification
在上一讲里,我们介绍了图像分类问题以及一个简单的分类模型K-NN模型,我们已经知道K-NN的模型有几个严重的缺陷,第一就是要保存训练集里的所有样本,这个比较消耗存储空间;第二就是要遍历所有的训练样本,这种逐一比较的方式比较耗时而低效。
现在,我们要介绍一种更加强大的图像分类模型,这个模型会很自然地引申出神经网络和Convolutional Neural Networks(CNN),这个模型有两个重要的组成部分,一个是score function,将原始数据映射到输出变量;另外一个就是loss function,衡量预测值与真实值之间的误差。
我们先看模型的第一部分,定义一个score function,将图像的像素值,映射到一个输出变量,这个输出变量表示图像属于每一类的置信度或者说概率,我们假设有一批训练图像,
xi∈RD
,每一个训练样本都有一个类标签
yi
,其中,
i=1,2,...N
,
yi∈{1,2,...K}
,就是说,我们有
N
个训练样本,这
在上面的表达式中,
xi
是一个高维向量,包含一幅图像的所有像素,将图像从
m×n×3
变成
D×1
,矩阵
W
(
为了能够视觉化这个过程,我们假设图像是只有四个像素(实际情况一般至少是几千个像素),将图像变成一个列向量然后与权值
W
相乘,在加上偏移向量
下图展示了线性分类模型对图像分类的过程,因为我们不能将高维向量可视化,所以我们假设在二维平面观看这些图像,那么线性分类模型在各个类别之间的边界就有可能如下图所示:
从上面可以看出,
W
的每一行都相当于某一类的分类器,从几何意义上看,如果我们改变
之前我们做运算和训练的时候,都是利用图像的原始数据,一般来说,我们需要做一些预处理,我们会将一个训练集里的所有样本做归一化。比如图像,将图像从[0,255]映射到
[-1,1]的范围,而且减去均值向量,保证训练集的均值为0。
我们已经介绍了score function,现在我们要介绍线性分类模型的另外一个重要组成部分:loss function,或者成为cost function,这个用来衡量预测值与目标值之间的误差。定义loss function的方式有很多,这里我们先介绍一种经常使用的loss function,叫做Multiclass Support Vector Machine (SVM) loss。简称 SVM loss,下面给出该函数的定义,假设训练集第
i
个样本的输入为
请注意,由于我们这里介绍的是线性模型 f(xi,W)=Wxi ,所以我们也可以将上式重新写成:
其中,
wTj
表示
W
的第
为了进一步提升模型的稳健性,我们会引入regularization penalty,
function就包含数据误差和regularization penalty两部分,如下式所示:
展开之后得到:
通过引入regularization penalty,可以使得权值的分布更加平衡,不会单独侧重于某些局部变量。
前面我们忽略了
Δ
值的探讨,
Δ
应该选择多少比较合适?在实际应用中,我们发现把
Δ
设为1.0是非常安全的,事实上,参数
Δ,λ
都是控制loss function中数据偏差与regularization penalty之间的平衡的,因为
W
的幅值对score有直接的影响,如果我们把幅值增大,那么预测的score也会变大,反之同样成立,所以
Softmax classifier
前面介绍的SVM是线性分类器,现在我们介绍另外一种常用的非线性分类器,Softmax classifier。SVM将预测值看做是一种score,而Softmax classifier将预测值看成是一种概率,Softmax classifier的映射函数没有变化,还是
这里,我们用
fj
表示对第
j
类的预测值,与SVM一样,整个训练集的loss function将是所有样本的平均loss加上regularization误差
因此,Softmax分类器是缩小预测的每一类的概率与实际概率的cross entropy。
从概率的角度来看,我们可以看到表达式:
可以看做是给定一张图像,其属于某一类的概率,指数项给出了概率值,而分母的归一化保证概率在[0,1]之间,而且其和为1,这样我们可以引入最大似然估计去解释这个
模型,如果进一步的,我们假设
W
是属于某一特定分布,比如高斯分布,那么我们可以用最大后验概率估计去解释这个模型,这里提到这些,只是为了让大家对此有一个
直观的了解。实际编写程序的时候,由于指数运算可能会涉及到很大的值,可能会使得模型在数值上不够稳定,所以一般会引入一个常数项
C的选择没有特别地规定,可以自由选择,通常我们定义 logC=−maxjfj 。下图显示了SVM与Softmax分类器做图像分类的区别:
声明:lecture notes里的图片都来源于该课程的网站,只能用于学习,
请勿作其它用途,如需转载,请说明该课程(http://cs231n.stanford.edu/)为引用来源。