本文是本人在学习台大林轩田教授的《机器学习基石》第二课 PLA 相关内容后的学习笔记。文中截图来自于林轩田教授的课件。
本文同步发布在本人的微信公众号,鉴于格式原因(csdn上下标不知道咋整),最好看微信公众号上的。
目录
1. PLA 简介
2. PLA 工作原理
3. 代码实现
4. 证明PLA一定收敛
5. Pocket PLA
1. PLA 简介
PLA 即 Perceptron Learning Algorithm(感知器学习算法),是一种二分类算法,它可以用于处理二维及更高维度数据集的二分类问题。通常分为正类(+1)和负类(-1),其假设函数(Hypothesis)如下:
由于-threshold = (-threshold)*(+1),所以上述式子可以改写成如下形式:
这样一改写就使得整个式子看起来更加简洁,直接使用一个向量的内积运算即可。
我们的目的就是找到一个超平面 wTx,使其能够恰好将正负两类分隔开,那么我们应该怎么去找这个 w 呢?看第二节!
2. PLA 工作原理
其工作原理可以用一句话概括:知错能改,善莫大焉!
没错,PLA 算法只需我们着眼于被错分的点,某个点分错了(本来是正类,分到了负类一边,或者反之),咱们就去纠正它的错误。
犯错无非两种情况:将正类错分到了负类的一边,或将负类错分到了正类的一边。下面咱们分情况讨论。
对于第一种情况:正类错分到负类。即:对于某被错分的点(x,y),我们想要的是 wTx > 0,而实际却是 wTx < 0,我们知道,如果两个向量内积小于0,即表示:这两个向量之间的夹角是大于90度的,所以咱们要使它们的夹角小于90度,使 wTx > 0,所以我们可以通过 wt+1←wt + (+1)x 更新 w,由于该错分的点 y=(+1),所以 wt+1← wt + yx,如下图所示:
对于第二种情况:负类错分到正类。即:对于某被错分的点(x,y),我们想要的是 wTx < 0,而实际却是 wTx > 0,我们知道,如果两个向量内积大于0,即表示:这两个向量之间的夹角是小于90度的,所以咱们要使它们的夹角大于90度,使 wTx < 0,所以我们可以通过 wt+1←wt + (-1)x 更新 w,由于该错分的点 y=(-1),所以 wt+1← wt + yx,如下图所示:
因此,不论对应哪种情况,对于 w 的更新,始终都是:wt+1← wt + yx。
3. 代码实现
在代码实现时,我们只需不断检测是否有被错分的点,如果有,矫正之,直到所有的点都在正确的一边。下面是我的代码实现:
自己瞎写了几个数据(存于 txt 文件中):
1 1 +1
2 2 +1
2 1 +1
2.6 0 +1
3.3 1.4 +1
3.4 2.6 +1
1.6 0 +1
0 0.9 -1
-1.2 0 -1
-1.8 -0.8 -1
-1.2 1.2 -1
-0.7 1.8 -1
0 3 -1
0 3.8 -1
0.6 4.2 -1
-1 3.3 -1
-2.2 2 -1
-1.4 2.3 -1
下面是主要代码,也就是 PLA 算法的实现,可见非常简单。w 初始化为0,判断某个点(x,y)是否被错分的条件是:w 和 x 的内积乘以它的标签 y 是否小于等于0,如果是,则表示其被错分了,需要更新 w,否则跳过。
def pla(X, y):
w = np.zeros(X.shape[1])
# 记录过程w(每一轮调整后的w)
ws = []
ws.append(w)
# 调整次数
adjude_num = 0
while True:
over = True
for i in range(X.shape[0]):
if np.dot(X[i, :], w) * y[i] <= 0:
w = w + y[i] * X[i, :].T
ws.append(w)
over = False
adjude_num+=1
else:
continue
if over:
break
return w,ws,adjude_num
为了更直观地看到 w 的更新过程,使用 pyplot 画出数据散点图及分割线,下面是画图的代码:
# 画出数据散点图
def plot_scatter(data, w, fig):
# 正例
x_p = []
# 反例
x_n = []
for l in data:
if l[2] == 1.:
x_p.append([l[0], l[1]])
else:
x_n.append([l[0], l[1]])
# 直线
x1 = np.linspace(-5,5,50)
if w[2] == 0:
x2 = 0*x1
else:
x2 = -w[0]/w[2]-w[1]*x1/w[2]
plt.figure(fig)
x_p = np.array(x_p)
x_n = np.array(x_n)
# 限制坐标轴的范围
plt.xlim(-3,4)
plt.ylim(-2,5)
plt.scatter(x_p[:,0], x_p[:,1], marker='o')
plt.scatter(x_n[:,0], x_n[:,1], marker='x')
plt.plot(x1,x2)
# plt.show()
执行结果如下:
w 变化过程如下:
4. 证明 PLA 一定收敛
PLA 是通过不停地纠正错分点,一点点调整 w,直到找到一个可以恰好完美地将数据集分开的超平面。我们从直观上看,只要数据集是线性可分的,就一定可以找到一个符合条件的超平面,算法就会停止,但是数学讲究的是严谨,需要有严格的推导证明去证实。
前提:数据集线性可分。
先定义几个符号:
wf | 我们想要的那个完美的 w |
t | 轮数 |
wt+1 | 第 t+1 轮的 w |
wt | 第 t 轮的 w |
(xn,yn) | 数据集中的某一个数据(某一个点) |
数据集线性可分 <=> 存在一个完美的 wf 使 yn= sign(wfTxn)
即,对于任意一个点(xn,yn),有 ynwfTxn≥ minn(ynwfTxn)> 0
又因为:wfTwt+1 =wfT(wt+ynxn)= wfTwt + ynwfTxn≥ wfTwt + minn(ynwfTxn)
进一步可以得到:wfTwt+1≥wfTwt + minn(ynwfTxn)≥ wfTwt-1 + 2*minn(ynwfTxn)…… ≥ w0 + t*minn(ynwfTxn)
w0 为初始化的 w,假设 w0 = 0(w 可以初始化为任意值,这里初始化为0方便证明)
所以得到:wfTwt+1≥(t+1)*minn(ynwfTxn)
即:wfTwt≥t*minn(ynwfTxn) (式①)
又:||wt+1||2= ||wt + ynxn||2 = ||wt||2+ 2ynwtTxn + ||ynxn||2
因为:ynwtTxn ≤ 0
所以:||wt+1||2= ||wt + ynxn||2 = ||wt||2 + ||ynxn||2≤ ||wt||2 + maxn(||ynxn||2)
即:||wt+1||2≤ ||wt||2 + maxn(||ynxn||2)≤ ||wt-1||2 + 2* maxn(||ynxn||2)…… ≤ ||w0||2 + (t+1)* maxn(||ynxn||2)= (t+1)* maxn(||ynxn||2)
即:||wt||2≤ t* maxn(||ynxn||2)
因为||yn|| = 1
所以进一步:||wt||2≤ t* maxn(||xn||2)
再进一步:||wt||≤ √t * max|| xn || (式②)
最后综合式 ①②:||wt||≤ √t * max|| xn ||
wfTwt≥ t * minn(ynwfTxn)
最后的计算过程实在不想敲了…… 就手写了
如下图,写的很丑,不过应该能看清!
所以,经过上面的计算,我们得到 t 小于等于一个常数,即轮数是有上限的,即 PLA 一定会停止。
证毕。
5. Pocket PLA
现实中,很少有那种能用一个超平面完美地将所有数据分开的情况,经常需要考虑到噪声(Noisy data)。如果数据集存在噪点,那么 PLA 算法就失灵了,因为它永远无法将所有的点用一个线性模型分开,它就永远不会停下来。这个时候就可以使用一种改进的 PLA —— Pocket PLA。
这里简单说一下它的思路就好,顾名思义,Pocket 即口袋,它是一种贪心策略,口袋里存放当前最合适的那个 w,如果有一个新的 w’,它犯的错更少(即错分的点更少),则用 w’ 更新 w,否则保持 w。在实现中,我们需要指定一个迭代轮数,否则算法不知道什么时候停止。