感知机学习算法的对偶形式

机器学习 同时被 3 个专栏收录
7 篇文章 0 订阅
7 篇文章 0 订阅
7 篇文章 0 订阅

感知机学习算法的原始形式:http://blog.csdn.net/qq_29591261/article/details/77934696

本文相对于原文在代码中添加了自己的理解和注释,省略推理过程,想看原理推导的请参考原文:http://www.hankcs.com/ml/the-perceptron.html


关于对偶

对偶,简单地说,就是从一个不同的角度去解答相似问题,但是问题的解是相通的。
或者说原始问题比较难求解,我们去求解另外一个问题,希望通过更简单的方法得到原始问题的解。
对于感知机来说,简单来说,就是用α去记录每个yixi要加多少次,最后一次加上去就好了。
具体理解如下:
这里写图片描述
来源于知乎:https://www.zhihu.com/question/26526858


感知机学习算法的对偶形式

对偶指的是,将w和b表示为测试数据i的线性组合形式,通过求解系数得到w和b。具体说来,如果对误分类点i逐步修改wb修改了n次,则w,b关于i的增量分别为这里写图片描述,这里这里写图片描述,则最终求解到的参数分别表示为:
这里写图片描述
于是有算法2.2:
这里写图片描述


感知机对偶算法代码

  1. # -*- coding:utf-8 -*-
  2. # Filename: train2.2.py
  3. # Authorhankcs
  4. # Date: 2015/1/31 15:15
  5. import numpy as np
  6. from matplotlib import pyplot as plt
  7. from matplotlib import animation
  8.  
  9. training_set = np.array([[[3, 3], 1], [[4, 3], 1], [[1, 1], -1], [[5, 2], -1]])    #训练样本
  10.  
  11. = np.zeros(len(training_set), np.float)    #矩阵a的长度为训练集样本数,类型为float
  12. = 0.0    #参数初始值为0
  13. Gram = None    #Gram矩阵
  14. = np.array(training_set[:, 1])    #y=[1 1 -1 -1]
  15. = np.empty((len(training_set), 2), np.float)    #x4*2的矩阵
  16. for i in range(len(training_set))#x=[[3., 3.], [4., 3.], [1., 1.], [5., 2.]]
  17.     x[i] = training_set[i][0]
  18. history = []    #history记录每次迭代结果
  19.  
  20. def cal_gram():
  21.     """
  22.     计算Gram矩阵
  23.     :return:
  24.     """
  25.     g = np.empty((len(training_set), len(training_set)), np.int)
  26.     for i in range(len(training_set)):
  27.         for j in range(len(training_set)):
  28.             g[i][j] = np.dot(training_set[i][0], training_set[j][0]) #G=[xi*xj]
  29.     return g
  30.  
  31.  
  32. def update(i):
  33.     """
  34.     随机梯度下降更新参数
  35.     :param i:
  36.     :return:
  37.     """
  38.     global a, b
  39.     a[i] += 1    #根据误分类点更新参数
  40.     b = b + 1 * y[i]    #这里1是学习效率η
  41.     history.append([np.dot(* y, x), b])    #history记录每次迭代结果
  42.     print a, b    #输出每次迭代结果
  43.  
  44.  
  45. #计算yi(Gram*xi+b),用来判断是否是误分类点
  46. def cal(i):
  47.     global a, b, x, y
  48.     res = np.dot(* y, Gram[i])
  49.     res = (res + b) * y[i] #返回
  50.     return res
  51.  
  52.  
  53. #检查是否已经正确分类
  54. def check():
  55.     global a, b, x, y
  56.     flag = False
  57.     for i in range(len(training_set)):    #遍历每个点
  58.         if cal(i) <= 0:    #如果yi(Gram*xi+b)<=0.则是误分类点
  59.             flag = True
  60.             update(i)    #用误分类点更新参数
  61.     if not flag: #如果已正确分类
  62.         w = np.dot(* y, x)    #计算w
  63.         print "RESULT: w: " + str(w) + " b:" + str(b)    #输出最后结果
  64.         return False
  65.     return True
  66.  
  67.  
  68. if __name__ == "__main__":
  69.     Gram = cal_gram()    #初始化 Gram矩阵
  70.     for i in range(1000):    #迭代1000
  71.         if not check()break    #如果已正确分类则结束循环
  72.  
  73.     #以下代码是将迭代过程可视化,数据来源于history
  74.     # first set up the figure, the axis, and the plotelement we want to animate
  75.     fig = plt.figure()
  76.     ax = plt.axes(xlim=(0, 2), ylim=(-2, 2))
  77.     line, = ax.plot([], [], 'g', lw=2)
  78.     label = ax.text([], [], '')
  79.  
  80.     # initialization function: plot the background of eachframe
  81.     def init():
  82.         line.set_data([], [])
  83.         x, y, x_, y_ = [], [], [], []
  84.         for p in training_set:
  85.             if p[1] > 0:
  86.                 x.append(p[0][0])
  87.                 y.append(p[0][1])
  88.             else:
  89.                 x_.append(p[0][0])
  90.                 y_.append(p[0][1])
  91.  
  92.         plt.plot(x, y, 'bo', x_, y_, 'rx')
  93.         plt.axis([-6, 6, -6, 6])
  94.         plt.grid(True)
  95.         plt.xlabel('x')
  96.         plt.ylabel('y')
  97.         plt.title('PerceptronAlgorithm 2 (www.hankcs.com)')
  98.         return line, label
  99.  
  100.  
  101.     # animation function. this is called sequentially
  102.     def animate(i):
  103.         global history, ax, line, label
  104.  
  105.         w = history[i][0]
  106.         b = history[i][1]
  107.         if w[1] == 0return line, label
  108.         x1 = -7.0
  109.         y1 = -(+ w[0] * x1) / w[1]
  110.         x2 = 7.0
  111.         y2 = -(+ w[0] * x2) / w[1]
  112.         line.set_data([x1, x2], [y1, y2])
  113.         x1 = 0.0
  114.         y1 = -(+ w[0] * x1) / w[1]
  115.         label.set_text(str(history[i][0]) + ' ' + str(b))
  116.         label.set_position([x1, y1])
  117.         return line, label
  118.  
  119.     # call the animator. blit=true means only re-draw the parts that have changed.
  120.     anim =animation.FuncAnimation(fig, animate, init_func=init, frames=len(history), interval=1000, repeat=True,
  121.                                    blit=True)
  122.     plt.show()
  123.     #anim.save('D:/perceptron2.gif',fps=2, writer='imagemagick')

    运行结果

    1. [ 1.  0.  0.  0.]1.0
    2. [ 1.  0.  1.  0.]0.0
    3. [ 1.  0.  1.  1.]-1.0
    4. [ 2.  0.  1.  1.]0.0
    5. [ 2.  0.  2.  1.]-1.0
    6. [ 2.  0.  3.  1.]-2.0
    7. [ 3.  0.  3.  1.]-1.0
    8. [ 3.  0.  4.  1.]-2.0
    9. [ 3.  0.  4.  2.]-3.0
    10. [ 4.  0.  4.  2.]-2.0
    11. [ 4.  0.  5.  2.]-3.0
    12. [ 5.  0.  5.  2.]-2.0
    13. [ 5.  0.  6.  2.]-3.0
    14. [ 5.  0.  6.  3.]-4.0
    15. [ 6.  0.  6.  3.]-3.0
    16. [ 6.  0.  7.  3.]-4.0
    17. [ 7.  0.  7.  3.]-3.0
    18. [ 7.  0.  8.  3.]-4.0
    19. [ 7.  0.  8.  4.]-5.0
    20. [ 8.  0.  8.  4.]-4.0
    21. [ 8.  0.  9.  4.]-5.0
    22. [ 8.  1.  9.  4.]-4.0
    23. [  8.   1. 10.   4.] -5.0
    24. [  8.   1. 10.   5.] -6.0
    25. [  9.   1. 10.   5.] -5.0
    26. [  9.   1. 11.   5.] -6.0
    27. RESULT: w: [-5.0 9.0] b:-6.0

      可视化

      这里写图片描述

      • 14
        点赞
      • 3
        评论
      • 28
        收藏
      • 一键三连
        一键三连
      • 扫一扫,分享海报

      ©️2021 CSDN 皮肤主题: 技术黑板 设计师:CSDN官方博客 返回首页
      实付
      使用余额支付
      点击重新获取
      扫码支付
      钱包余额 0

      抵扣说明:

      1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
      2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

      余额充值