感知机原理参考博客:
【机器学习】感知机原理详解
感知机模型: f ( x ) = s i g n ( w ∗ x + b ) f(x)=sign(w*x+b) f(x)=sign(w∗x+b)
s
i
g
n
sign
sign是符号函数
感知机模型的其中一个超平面是:
w
∗
x
+
b
=
0
w*x+b=0
w∗x+b=0
w w w是超平面的法向量, b b b是超平面的截距
这个超平面把样本分为正负两类(结合 s i g n ( ) sign() sign()函数)
主要思路:
1.输入训练数据集:
编号 | 宽度 | 长度 | 检验类别 |
---|---|---|---|
1 | 3 | 3 | 正品 |
2 | 4 | 3 | 正品 |
3 | 1 | 1 | 次品 |
将正负样本区分开
x = np.array([[3,3],[4,3],[1,1]],dtype=np.float64)
y = np.array([1,1,-1],dtype=np.float64)
l=len(x)
x_positive = []
x_negetive = []
for i in range(l):
if y[i]==1:
x_positive.append(x[i])
else:
x_negetive.append(x[i])
x_positive = np.array(x_positive)
x_negetive = np.array(x_negetive)
2.赋值 w 0 , b 0 w_0,b_0 w0,b0
w=np.array([0.0,0.0])
b=np.array([0.0])
lr = 0.5
square = lambda x:x*x
def sign(v):
if v>=0:
return 1
else:
return -1
flag=0
3.选取数据点 ( x i , y i ) (x_i,y_i) (xi,yi)
4.判断该数据点是否为当前模型的误分类点,并进行更新
y i ( w ∗ x i + b ) < = 0 y_i(w*x_i+b)<=0 yi(w∗xi+b)<=0时更新
while flag==0:
for i in range(3):
y_pred = sign(np.matmul(w, x[i])+b)
if y[i] * y_pred < 0: #是误分类点。
partial_w = (-y[i]*x[i])
partial_b = (-y[i])
#更新:
w = w-lr*partial_w
b = b-lr*partial_b
5.若为误分类点则转到第三步,直到没有误分类点为止
完整代码如下:
import time
import numpy as np
import matplotlib.pyplot as plt
#训练数据
x = np.array([[3,3],[4,3],[1,1]],dtype=np.float64)
y = np.array([1,1,-1],dtype=np.float64)
l=len(x)
x_positive = []
x_negetive = []
for i in range(l):
if y[i]==1:
x_positive.append(x[i])
else:
x_negetive.append(x[i])
x_positive = np.array(x_positive)
x_negetive = np.array(x_negetive)
#参数定义
w=np.array([0.0,0.0])
b=np.array([0.0])
lr = 0.5
square = lambda x:x*x
def sign(v):
if v>=0:
return 1
else:
return -1
flag=0
#感知机学习算法是有误分类驱动的
while flag==0:
for i in range(3):
y_pred = sign(np.matmul(w, x[i])+b)
if y[i] * y_pred < 0: #是误分类点。
partial_w = (-y[i]*x[i])
partial_b = (-y[i])
#更新:
w = w-lr*partial_w
b = b-lr*partial_b
#画图
if w[1]==0:
continue
plt.xlim(0,10)
plt.ylim(-5,5)
plt.scatter(x_positive[:,0],x_positive[:,1],color='blue')
plt.scatter(x_negetive[:,0],x_negetive[:,1],color='red')
px=np.linspace(0,10,10)
py=(-w[0]/w[1])*px-b/w[1]
plt.plot(px,py)
plt.pause(0.1)
plt.clf()
#检验
rst=[]
for i in range(3):
if y[i]*sign(np.matmul(w, x[i])+b)>0:
#分类正确,则记为1
rst.append(1)
else:
rst.append(0)
#如果全都分类正确
if min(rst)==1:
flag=1
time.sleep(3)
break
print('w=',w)
print('b=',b)