更新日志
-
更新时间:2020年10月10日
- 更新内容:最近在重新看SVM和SVR部分,对于分类问题的原理有了一个更清晰的认知,在此记录一下。
二分类问题
如上图所示的一个线性可分的二分类问题,红点表示为一类,蓝点表示为另一类,我们的目标是找到一条直线讲两类点完全分开。从直观上讲,我们很容易直接在图中画出这样一条直线,使得同一类的点都在直线的同一侧。接下来我们来看一下这样画出的直线有什么性质。
设
A
(
x
1
,
y
1
)
A(x_1,y_1)
A(x1,y1)为红类中的任意一点,
B
(
x
2
,
y
2
)
B(x_2, y_2)
B(x2,y2)为蓝类中的任意一点。分别过点A和B做
x
x
x轴的垂线和分界直线分别交于
C
(
x
1
,
y
3
)
C(x_1,y_3)
C(x1,y3)和
D
(
x
2
,
y
4
)
D(x_2,y_4)
D(x2,y4)。
设直线方程为
a
x
+
b
y
+
c
=
0
ax+by+c=0
ax+by+c=0。我们有
a
x
1
+
b
y
3
+
c
=
0
ax_1+by_3+c=0
ax1+by3+c=0
a
x
2
+
b
y
4
+
c
=
0
ax_2+by_4+c=0
ax2+by4+c=0
且
y
1
<
y
3
,
y
2
>
y
4
y_1<y_3, y_2>y_4
y1<y3,y2>y4,分别把A,B点坐标代入直线方程,有
a
x
1
+
b
y
1
+
c
<
0
ax_1+by_1+c<0
ax1+by1+c<0
a
x
2
+
b
y
2
+
c
>
0
ax_2+by_2+c>0
ax2+by2+c>0
于是我们发现,对于直线同侧的点代入直线方程 a x + b y + c ax+by+c ax+by+c,其结果总是同号的,而直线异侧的点代入其结果总是异号的。 而直线上的点代入总有 a x + b y + c = 0 ax+by+c=0 ax+by+c=0。
我们将点的坐标 ( x , y ) (x,y) (x,y)看成一个二维向量 X X X,上述直线方程又可以表示为 W T X + b = 0 W^TX+b=0 WTX+b=0,这里的 W W W也是一个和 X X X同维度的向量,并且该式子同样能描述 X X X为多维的情况。
因此我们可以将二分类问题转化为数学描述:假设有两种类别的点(可以是多维的),我们要找到一条直线 W T X + b = 0 W^TX+b=0 WTX+b=0(多维时为超平面)将它们分割开来,使得同类点代入该直线方程结果总是同号的,异类点代入直线方程结果总是异号的。
我们以符号来区分两类样本,正类样本标记为1,负类样本标记为-1。要求的是决策边界,也就是上面的直线,即直线的参数为
w
w
w和
b
b
b。预测时,我们将待测样本点代入直线方程,结果为正即为正类样本,结果为负即为负类样本。
不同的线性分类器的目标是一致的,都是要求决策边界。主要区别在于损失函数的不同。上图中描述的代价函数为感知器模型的代价函数。
感知器的目标是让误分类点到决策边界的距离尽可能小。 据此,给出上述的损失函数,即当分类正确时,损失为0;分类错误时,损失为 − y ( w x + b ) -y(wx+b) −y(wx+b)。(仔细分析损失函数,当分类错误时, y y y和 w x + b wx+b wx+b是异号的,损失函数前面添加了一个负号,保证了损失总是大于等于0的。然后通过最小化损失函数来求解参数。)当然通过求误分类点到决策边界的距离同样可以推导出上述损失函数。
梯度下降法求解参数:
主要思想就是,给定损失函数,要让损失函数最小化,只需要让参数每次往梯度下降的方向走一小步,这样总能到达一个极小值点。梯度下降法最重要的部分就是求解损失函数关于参数的梯度。
上述感知器的损失函数
L
=
−
y
(
w
x
+
b
)
L=-y(wx+b)
L=−y(wx+b)对于参数的梯度为:
∂
L
∂
w
=
−
y
x
\frac{\partial L}{\partial w}=-yx
∂w∂L=−yx
∂
L
∂
b
=
−
y
\frac{\partial L}{\partial b}=-y
∂b∂L=−y
因此对于误分类样本,参数更新方式为
w
=
w
−
(
−
y
x
)
w=w-(-yx)
w=w−(−yx)
b
=
b
−
(
−
y
)
b=b-(-y)
b=b−(−y)
常用的梯度下降法有三种:
批梯度下降法:每次计算完全部训练样本的loss以后,再更新参数
随机梯度下降法:每次使用一个样本更新参数
小批量梯度下降法:每次使用一个batch的样本更新参数
感知器
定义感知器模型:
f
(
x
)
=
s
i
g
n
(
w
⋅
x
+
b
)
f(x)=sign(w \cdot x+b)
f(x)=sign(w⋅x+b)
其中 s i g n sign sign为符号函数,即当 w ⋅ x + b > 0 w \cdot x+b >0 w⋅x+b>0时 f ( x ) = 1 f(x)=1 f(x)=1;当 w ⋅ x + b < 0 w \cdot x+b<0 w⋅x+b<0时 f ( x ) = − 1 f(x)=-1 f(x)=−1。
定义损失函数为误分类点到超平面的总距离:
min
w
,
b
L
(
w
,
b
)
=
∑
x
i
∈
M
−
y
i
(
w
⋅
x
i
+
b
)
\min \limits_{w,b} L(w,b)=\sum_{x_i \in M} -y_i(w \cdot x_i +b)
w,bminL(w,b)=xi∈M∑−yi(w⋅xi+b)
其中 M M M为误分类点的集合。
参考资料:
[1] 新手入门:感知器
[2] Coursera机器学习基石 第2讲:感知器
LinearSVM
参考资料:
[1] Python · SVM(二)· LinearSVM