PLA 学习笔记

本文是本人在学习台大林轩田教授的《机器学习基石》第二课 PLA 相关内容后的学习笔记。文中截图来自于林轩田教授的课件。

 

本文同步发布在本人的微信公众号,鉴于格式原因(csdn上下标不知道咋整),最好看微信公众号上的。

 

目录

        1. PLA 简介

        2. PLA 工作原理

        3. 代码实现

        4. 证明PLA一定收敛

        5. Pocket PLA

 

1. PLA 简介

PLA 即 Perceptron Learning Algorithm(感知器学习算法),是一种二分类算法,它可以用于处理二维及更高维度数据集的二分类问题。通常分为正类(+1)和负类(-1),其假设函数(Hypothesis)如下:

 

 

由于-threshold = (-threshold)*(+1),所以上述式子可以改写成如下形式:

 

 

这样一改写就使得整个式子看起来更加简洁,直接使用一个向量的内积运算即可。

 

我们的目的就是找到一个超平面 wTx,使其能够恰好将正负两类分隔开,那么我们应该怎么去找这个 呢?看第二节!

 

2. PLA 工作原理

其工作原理可以用一句话概括:知错能改,善莫大焉!

没错,PLA 算法只需我们着眼于被错分的点,某个点分错了(本来是正类,分到了负类一边,或者反之),咱们就去纠正它的错误。

 

犯错无非两种情况:将正类错分到了负类的一边,或将负类错分到了正类的一边。下面咱们分情况讨论。

 

对于第一种情况:正类错分到负类。即:对于某被错分的点(x,y),我们想要的是 wTx > 0,而实际却是 wTx < 0,我们知道,如果两个向量内积小于0,即表示:这两个向量之间的夹角是大于90度的,所以咱们要使它们的夹角小于90度,使 wTx > 0,所以我们可以通过 wt+1←wt + (+1)x 更新 w,由于该错分的点 y=(+1),所以 wt+1← wt + yx,如下图所示:

 

对于第二种情况:负类错分到正类。即:对于某被错分的点(x,y),我们想要的是 wTx < 0,而实际却是 wTx > 0,我们知道,如果两个向量内积大于0,即表示:这两个向量之间的夹角是小于90度的,所以咱们要使它们的夹角大于90度,使 wTx < 0,所以我们可以通过 wt+1←wt + (-1)x 更新 w,由于该错分的点 y=(-1),所以 wt+1← wt + yx,如下图所示:

 

 

因此,不论对应哪种情况,对于 的更新,始终都是:wt+1← wt + yx

 

3. 代码实现 

在代码实现时,我们只需不断检测是否有被错分的点,如果有,矫正之,直到所有的点都在正确的一边。下面是我的代码实现:

自己瞎写了几个数据(存于 txt 文件中):

 


1 1 +1
2 2 +1
2 1 +1
2.6 0 +1
3.3 1.4 +1
3.4 2.6 +1
1.6 0 +1
0 0.9 -1
-1.2 0 -1
-1.8 -0.8 -1
-1.2 1.2 -1
-0.7 1.8 -1
0 3 -1
0 3.8 -1
0.6 4.2 -1
-1 3.3 -1
-2.2 2 -1
-1.4 2.3 -1

 

下面是主要代码,也就是 PLA 算法的实现,可见非常简单。初始化为0,判断某个点(x,y)是否被错分的条件是:和 的内积乘以它的标签 y 是否小于等于0,如果是,则表示其被错分了,需要更新 w,否则跳过。

 

def pla(X, y):
    w = np.zeros(X.shape[1])

    # 记录过程w(每一轮调整后的w)
    ws = []
    ws.append(w)

    # 调整次数
    adjude_num = 0

    while True:
        over = True
        for i in range(X.shape[0]):
            if np.dot(X[i, :], w) * y[i] <= 0:
                w = w + y[i] * X[i, :].T
                ws.append(w)
                over = False
                adjude_num+=1
            else:
                continue
        if over:
            break

    return w,ws,adjude_num

 

为了更直观地看到 的更新过程,使用 pyplot 画出数据散点图及分割线,下面是画图的代码:

 

# 画出数据散点图
def plot_scatter(data, w, fig):
    # 正例
    x_p = []
    # 反例
    x_n = []
    for l in data:
        if l[2] == 1.:
            x_p.append([l[0], l[1]])
        else:
            x_n.append([l[0], l[1]])

    # 直线
    x1 = np.linspace(-5,5,50)
    if w[2] == 0:
        x2 = 0*x1
    else:
        x2 = -w[0]/w[2]-w[1]*x1/w[2]

    plt.figure(fig)

    x_p = np.array(x_p)
    x_n = np.array(x_n)

    # 限制坐标轴的范围
    plt.xlim(-3,4)
    plt.ylim(-2,5)
    plt.scatter(x_p[:,0], x_p[:,1], marker='o')
    plt.scatter(x_n[:,0], x_n[:,1], marker='x')
    plt.plot(x1,x2)
   # plt.show()

 

执行结果如下:

 

 

w 变化过程如下:

 

4. 证明 PLA 一定收敛

PLA 是通过不停地纠正错分点,一点点调整 w,直到找到一个可以恰好完美地将数据集分开的超平面。我们从直观上看,只要数据集是线性可分的,就一定可以找到一个符合条件的超平面,算法就会停止,但是数学讲究的是严谨,需要有严格的推导证明去证实。

 

前提:数据集线性可分。

 

先定义几个符号:

wf我们想要的那个完美的 w
t轮数
wt+1第 t+1 轮的 w
wt第 t 轮的 w
(xn,yn)数据集中的某一个数据(某一个点)

数据集线性可分 <=> 存在一个完美的 wf 使 yn= sign(wfTxn)

即,对于任意一个点(xn,yn),有 ynwfTxn≥ minn(ynwfTxn)> 0

 

又因为:wfTwt+1 =wfT(wt+ynxn)= wfTwt + ynwfTxn≥ wfTwt + minn(ynwfTxn)

进一步可以得到:wfTwt+1≥wfTwt + minn(ynwfTxn)≥ wfTwt-1 + 2*minn(ynwfTxn)…… ≥ w0 + t*minn(ynwfTxn)

 

w0 为初始化的 w,假设 w0 = 0可以初始化为任意值,这里初始化为0方便证明)

 

所以得到:wfTwt+1≥(t+1)*minn(ynwfTxn)

即:wfTwt≥t*minn(ynwfTxn) (式①)

 

又:||wt+1||2= ||wt + ynxn||2 = ||wt||2+ 2ynwtTxn + ||ynxn||2

因为:ynwtTxn ≤ 0

所以:||wt+1||2= ||wt + ynxn||2 = ||wt||2 + ||ynxn||2≤ ||wt||2 + maxn(||ynxn||2)

 

即:||wt+1||2≤ ||wt||2 + maxn(||ynxn||2)≤ ||wt-1||2 + 2* maxn(||ynxn||2)…… ≤ ||w0||2 + (t+1)* maxn(||ynxn||2)= (t+1)* maxn(||ynxn||2)

即:||wt||2≤ t* maxn(||ynxn||2)

因为||yn|| = 1

所以进一步:||wt||2≤ t* maxn(||xn||2)

再进一步:||wt||≤ √t * max|| xn ||  (式②)

 

最后综合式 ①②:||wt||≤ √t * max|| xn ||

           wfTwt≥ t * minn(ynwfTxn)

 

最后的计算过程实在不想敲了…… 就手写了

如下图,写的很丑,不过应该能看清!

 

 

所以,经过上面的计算,我们得到 t 小于等于一个常数,即轮数是有上限的,即 PLA 一定会停止。

 

证毕。

 

 

5. Pocket PLA

现实中,很少有那种能用一个超平面完美地将所有数据分开的情况,经常需要考虑到噪声(Noisy data)。如果数据集存在噪点,那么 PLA 算法就失灵了,因为它永远无法将所有的点用一个线性模型分开,它就永远不会停下来。这个时候就可以使用一种改进的 PLA —— Pocket PLA。

 

这里简单说一下它的思路就好,顾名思义,Pocket 即口袋,它是一种贪心策略,口袋里存放当前最合适的那个 w,如果有一个新的 w’,它犯的错更少(即错分的点更少),则用 w’ 更新 w,否则保持 w。在实现中,我们需要指定一个迭代轮数,否则算法不知道什么时候停止。

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值