机器学习基石---第二周PLA

knitr::opts_chunk$set(echo = TRUE)

  台大《机器学习基石》第二周课的笔记,只整理部分重要内容。希望能把课上学的,做一个精简的记录。

变量说明

  存在两类数据,标记为 y ,取值为1,1。特征向量记为 x x=(x0,x1,x2,...,xd)。其中 x0 为常量1,其余为具体特征值。存在超平面 wTx=0 ,其中 w=(w0,w1,...,wd) ,可以正确分开两类数据。共有 N 个样本数据。

迭代过程

  PLA采取知错就改的策略。遍历所有样本,如果发现分类错误,采用如下方式如下方式更新w
  

Fort=0,1,...N1.findamistakeofwtcalled(xn(t),yn(t))sign(wTtxn(t))yn(t)2.(tryto)correctthemistakebywt+1wt+yn(t)xn(t)...untilnomoremistakesreturnlastw(calledwPLA)asg

更新理由

这里写图片描述
  判断类别的公式:

sign(wTtxn(t))=sign(wTtxn(t)cos(θ))

  如果正类被误判,则 cos(θ)<0 ,即 θ(π2,π) ,所以要缩小法向量和特征向量之间的夹角。故采用上图方法迭代 w 的值。

证明

  证明线性可分数据集,PLA算法一定能够经过有限次的迭代,得到一个完美的分割超平面。

每一次迭代wt更接近 wf

  1. wf 为完美分类器
  2. (xn,yn) 为错分的样本
  3. (xn(t),yn(t)) 为第t次迭代时, wt 错分的样本
  因为 wf 是完美分类器,则一定有:

yn(t)wTfxn(t)minnynwTfxn>0

  利用任意一个错判样本 (xn(t),yn(t)) 进行第 t+1 次迭代之后,计算:

wTfwt+1wTfwt+1=wTf(wt+yn(t)xn(t))wTfwt+1=wTfwt+yn(t)wTfxn(t)wTfwt+1wTfwt+minnyn(t)wTfxn(t)wTfwt+1>wTfwt+0wTfwt+1=wTfwtwTfwt+1

  从余弦相似度的角度看,通过错判样本对 wt 的修正,使得迭代后的 w 更接近于完美的分割超平面。

每一次迭代wt的模增长较小

wt+12=wt+yn(t)xn(t)2=wt2+2yn(t)wTtxn(t)+yn(t)xn(t)2wt2+0+yn(t)xn(t)2wt2+maxnynxn2

迭代次数有限

  假设 w0=0 ,经过 T 次迭代之后:

wTfwTwfwT=wTf(wT1+yn(T1)xn(T1))wfwT=wTf(wT1+yn(T1)xn(T1))wfwT=wTfwT1+yn(T1)wTfxn(T1)wfwTwTfwT1+minnynwTfxnwfwTwTfwT2+yn(T2)wTfxn(T2)+minnynwTfxnwfwTwTfwT2+2minnynwTfxnwfwTTminnynwTfxnwfwTFurther:wTfwTTminnynwTfxnTwTfwTminnynwTfxnT2(wTfwT)2(minnynwTfxn)2=wf2wT2sin2(θ)(minnynwTfxn)2wf2wT2(minnynwTfxn)2wf2maxynxnn2(minnynwTfxn)2=wf2maxxnn2(minnynwTfxn)2

  所以迭代次数 T <script type="math/tex" id="MathJax-Element-33">T</script>有上界。

案例

构造数据集

  构造数据集,验证算法。

x11 <- 1:10
x21 <- x11 + runif(10, 0, 1) + 3
x22 <- x11 - runif(10, 0, 1)
example_data <- data.frame(x1 = rep(x11, 2),
x2 = c(x21, x22),
label = rep(c(1, -1), each = 10))
example_data$label <- as.factor(example_data$label)
library(ggplot2)
ggplot(data = example_data, aes(
x = x1,
y = x2,
color = label,
shape = label
)) +
geom_point()

这里写图片描述

PLA算法

## 参数:数据集、标签名称

PLA_f <- function(dataset, label) {
  ## 样本数
  row_num <-  nrow(dataset)
  w <- rep(1, ncol(dataset))
  w0 <- matrix(w, 1, 3, byrow = T)
  real_label <- as.numeric(as.vector(dataset[, label]))
  feature_matrix <-
    as.matrix(data.frame(x0 = rep(1, row_num), cbind(dataset[, setdiff(colnames(dataset), label)])))
  i <- 1
  j <- 0
  while (i < row_num & j == 0) {
    i <- 1
    j <- 0
    for (i in 1:row_num) {
      ## 判断是否有误判
      if (as.vector(feature_matrix[i,] %*% t(w0)) * real_label[i] <= 0) {
        ## 存在误判,修正w0
        w0 <- w0 + real_label[i] * feature_matrix[i,]
        w <- c(w, w0)
        j <- 1
      }
      if(j == 1){
        j <- 0
        i <- row_num-1
        break()}
    }
  }
  w_data <- data.frame(matrix(w,ncol=ncol(dataset),byrow = TRUE))
  colnames(w_data) <- paste0("x",0:(ncol(feature_matrix)-1))
  w_data <- dplyr::mutate(w_data,
                          slope = -x1 / x2,
                          intercept = -x0 / x2)
  return(w_data)
}

求解

w_data <- PLA_f(dataset = example_data, label = "label")
w_data
   x0 x1           x2        slope    intercept
1   1  1  1.000000000   -1.0000000   -1.0000000
2   0  0  0.495471116    0.0000000    0.0000000
3  -1 -1 -0.009057768 -110.4024725 -110.4024725
4   0  0  4.912654036    0.0000000    0.0000000
5  -1 -1  4.408125152    0.2268538    0.2268538
6  -2 -2  3.903596268    0.5123481    0.5123481
7  -3 -4  1.915120282    2.0886417    1.5664812
8  -2 -1  8.363856425    0.1195621    0.2391241
9  -3 -2  7.859327541    0.2544747    0.3817120
10 -4 -4  5.870851555    0.6813322    0.6813322
11 -5 -9  1.747566727    5.1500179    2.8611211
12 -4 -8  6.669278532    1.1995300    0.5997650

动图

library(animation)
## 指定ImageMagic目录位置,注意是magick.exe,之前版本貌似一致是convert.exe
ani.options(convert = "D:/ImageMagic/ImageMagick-7.0.7-Q16/magick.exe")
saveGIF(
  expr = {
    library(ggplot2)
    for (i in 1:nrow(w_data)) {plot(
      x = example_data$x1[1:10],
      y = example_data$x2[1:10],
      pch = 15,
      col = "red",
      xlim = c(0, 20),
      ylim = c(0, 15),
      xlab = "x1",
      ylab = "x2",main = paste0("Picture",i)
    )
      lines(x = example_data$x1[11:20],
            y = example_data$x2[11:20],
            type = "p",
            pch = 17,
            col = "blue")
      abline(coef=c(w_data$intercept[i],w_data$slope[i]),lwd=2)
      }
  },
  ## GIF文件名,注意文件后缀名要加上
  movie.name = "PLA.gif",
  ## 时间间隔
  interval = 1,
  ## 图形设置
  ani.width = 600,
  ani.height = 600,
  ## 文件输出在当前目录
  outdir = getwd()
)

这里写图片描述

Ref

[1]课程PPT

2017-12-19于杭州

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值