今天刚学感知机,先拿李航老师的《统计学习方法》中的 例2.1 练练手。
我这里使用的是R语言
原始形式
首先判断点是否能被正确分类,若不能则更新 w和b ,直到所有的点都能被正确分类。
1、取初值w=0,b=0;
2、若 ,
则更新
3、返回步骤2,看看是否满足,若不满足则一直训练到最后一个点都被正确分类。
代码演示 此处 所以我偷懒一下
x = cbind(c(3,3),c(4,3),c(1,1))
y = matrix(c(1,1,-1))
w = c(0,0)
b = 0
plot(x[1,],x[2,])
gzj = function(x,y,w,b){
while (sum((w%*%x+b)*t(y) <= 0) > 0 ) {
for (i in 1:length(y)) {
if (y[i]*(w%*%x[,i]+b) <= 0) {
w = w+y[i]*x[,i]
b = b+y[i]
}
}
}
return(list(w=w,b=b))
}
gzj(x,y,w,b)
最终结果
$w
[1] 1 1
$b
[1] -3
和课本的一样