1.pocket算法
对PLA的一种补充,面对无法线性可分的变式,寻找分错率最低的超平面。若未看过PLA的建议点开前一篇PLA的数学推导
- 二者的区别:
PLA对非线性可分数据分类,根据T<= R 2 ρ 2 \frac{R^2}{\rho^2} ρ2R2可知,无法找到一个使得 ρ ≠ 0 \rho\neq0 ρ=0的超平面,故T趋向于无穷大,导致其无法收敛。而pocket通过控制迭代次数的多少来获得更优的超平面,但是这也导致其算法性能低于PLA,每产生一个超平面就要求其分错率。 - 二者的联系
其本质都是求沿着梯度方向求最小损失的极值点,PLA能使极值最小化为0,而pocket可随着迭代次数来近似得到最小极值点。
2.代码实现
我们将两个正态分布的均值调近一些就可造成无法线性可分的数据集,将上个PLA的参数由
N
1
(
2
,
1
)
N_1(2,1)
N1(2,1),
N
2
(
−
2
,
1
)
N2(-2,1)
N2(−2,1)改为
N
1
(
1
,
1
)
N_1(1,1)
N1(1,1),
N
2
(
−
1
,
1
)
N_2(-1,1)
N2(−1,1)作为新的训练数据。
from matplotlib import pyplot as plt
import numpy as np
#生成训练数据
def generate_data(u1,o1,u2,o2,n,m):
#此产生两组正态分布数据(产出为元组数据)
t1=np.random.normal(u1,o1,size=(n,2))
t2=np.random.normal(u2,o2,size=(m,2))
a_x=np.array(t1)
b_x=np.array(t2)
#给两组正态数据打标签
a_y=np.ones(n)
b_y=np.negative(np.ones(m))
class1=np.c_[a_x,a_y]
class2=np.c_[b_x,b_y]
return class1,class2
#计算错误率
def checkErrorRate(test,w):
count=0
for i in range(len(test)):
x=np.array(test[i][:-1])
y=np.dot(x,w)
if np.sign(test[i][-1])!=np.sign(y):
count+=1
return count/len(test)
#pocket算法实现
def pocket():
w=np.zeros(3)#初始化w0
best_w=w
bestRate=1
n=50
m=50
c1,c2=generate_data(1,1,-1,1,n,m)
test=np.vstack((c1,c2))#合并两类数据
x0 = np.ones(n+m)
test = np.c_[x0, test] # 插入列向量x0=[0,0,...0]
cnt=0
while True:
cnt+=1
if cnt>1000:#pocket与pla不同的一点就在于他靠控制迭代次数来提高分类精度
break
success=True
for i in range(len(test)):
x=np.array(test[i][:-1])
y=np.dot(x,w)
if np.sign(y)==np.sign(test[i][-1]):
continue
w=w+test[i][-1]*x #更新w值
rate=checkErrorRate(test,w)#得出分错率
if rate<bestRate:#如果分错率更小则替换当前最好的w
bestRate=rate
best_w=w
success = False
break
if success==True:
break
# 绘制分类前效果
plt.scatter(c1[:, 0], c1[:, 1], c='r', marker='o') # 正类正态分布
plt.scatter(c2[:, 0], c2[:, 1], c='b', marker='x') # 负类正态分布
plt.show()
# 绘制分类后效果
x = np.linspace(min(test[:, 1]) - 1, max(test[:, 2]) + 1, 50)
y = -w[1] / w[2] * x - w[0] / w[2] # 见下文数学推导
plt.plot(x, y, c='g') # 超平面
plt.scatter(c1[:, 0], c1[:, 1], c='r', marker='o') # 正类正态分布
plt.scatter(c2[:, 0], c2[:, 1], c='b', marker='x') # 负类正态分布
plt.show()
return cnt, w
if __name__=='__main__':
cnt,w=pocket()
print("迭代次数:",cnt)
print("超平面法向量:",w)
3.分类结果
分类前效果:
分类后效果(迭代次数不够导致分类效果并不好,读者可以自己试着提供迭代次数来到达更好的分类效果):