最近开始学习统计学习方法(李航),作为小白入门统计学习,希望用博客记录自己的学习历程,同时与大家分享自己的代码。希望看到这篇博客的各位程序大佬们能够指出我的错误!感知机模型代码如下:
import numpy as np
# 感知机模型
def perceptron(x, y, a0, b0, eta):
n = len(y)
a, b = a0, b0
Gram = np.dot(x, x.T) # Gram矩阵
niter = 0 # 迭代计数
while True:
i = 0
while i < n:
if y[i] * (sum(a*y*Gram[i, :]) + b) <= 0:
# 更新参数
a[i] = a[i] + eta
b = b + eta * y[i]
niter += 1
break
else:
i += 1
if i == n: # 迭代终止条件
w = np.dot(a * y, x)
break
return w, b, niter
if __name__ == "__main__":
# 数据及参数
x = np.array([[3, 3],
[4, 3],
[1, 1]])
y = np.array([1, 1, -1])
a0 = np.zeros(len(y))
b0 = 0
eta = 1 # 学习率
# 感知机训练
w, b, niter = perceptron(x, y, a0, b0, eta)
print("系数为:", w, b)
print("迭代次数为:", niter)