线性可分:感知机

这篇博客介绍了感知机模型及其基于错误驱动的学习策略。在确保数据线性可分的前提下,通过初始化权重和偏置,利用随机梯度下降法对误分类样本进行迭代更新。实验数据使用Python和相关库展示了感知机的训练过程,并绘制了决策边界。
摘要由CSDN通过智能技术生成

感知机思想:错误驱动
模型:
f ( x ) = s i g n ( W T x ) , x ∈ R p , W ∈ R p f(x)=sign(W^Tx),x\in \R^p,W\in \R^p f(x)=sign(WTx)xRpWRp
s i g n ( a ) = { + 1 , a ≥ 0 − 1 , a < 0 sign(a)=\left\{\begin{matrix} +1,a\geq0 \\ -1,a<0 \end{matrix}\right. sign(a)={+1,a01,a<0

前提:数据是线性可分的
样本集: { x i , y i } i = 1 N \{x_i,y_i\}_{i=1}^{N} {xi,yi}i=1N
先给 W W W一个初始值 W 0 W_0 W0
D : 被 错 误 分 类 的 样 本 D:{被错误分类的样本} D:
策略:
loss function:被错误分类的点的个数
L ( W ) = ∑ i = 1 N I { y i W T x i < 0 } L(W)=\sum\limits_{i=1}^{N}I\{y_iW^Tx_i<0\} L(W)=i=1NI{yiWTxi<0}
当样本点被正确分类时: y i W T x i > 0 y_iW^Tx_i>0 yiWTxi>0
W T x i > 0 W^Tx_i>0 WTxi>0时, y i = + 1 y_i=+1 yi=+1
W T x i < 0 W^Tx_i<0 WTxi<0时, y i = − 1 y_i=-1 yi=1
那么样本点被错误分类时, y i W T x i < 0 y_iW^Tx_i<0 yiWTxi<0
但是此时 L ( W ) L(W) L(W)不可导,所以这个损失函数不合适。所以改用以下的损失函数

L ( W ) = ∑ x i ∈ D − y i W T x i L(W)=\sum\limits_{x_i\in D} -y_iW^Tx_i L(W)=xiDyiWTxi
在代码的时候可以用随机梯度下降优化

实验数据:
在这里插入图片描述

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib 
import matplotlib.pyplot as plt

data = pd.read_csv('data/9-2-data.csv')
dataMat = np.mat(data.iloc[:,:2].values)
labelMat = np.mat(data.iloc[:,-1].values).T
m, n = np.shape(dataMat)
w = np.zeros((1, np.shape(dataMat)[1]))
#初始化偏置b为0
b = 0
#初始化步长,也就是梯度下降过程中的n,控制梯度下降速率
h = 0.0001
for k in range(50):
        #对于每一个样本进行梯度下降
        #李航书中在2.3.1开头部分使用的梯度下降,是全部样本都算一遍以后,统一
        #进行一次梯度下降
        #在2.3.1的后半部分可以看到(例如公式2.6 2.7),求和符号没有了,此时用
        #的是随机梯度下降,即计算一个样本就针对该样本进行一次梯度下降。
        #两者的差异各有千秋,但较为常用的是随机梯度下降。
        for i in range(m):
            #获取当前样本的向量
            xi = dataMat[i]
            #获取当前样本所对应的标签
            yi = labelMat[i]
            #判断是否是误分类样本
            #误分类样本特诊为: -yi(w*xi+b)>=0,详细可参考书中2.2.2小节
            #在书的公式中写的是>0,实际上如果=0,说明改点在超平面上,也是不正确的
            if -1 * yi * (w * xi.T + b) >= 0:
                #对于误分类样本,进行梯度下降,更新w和b
                w = w + h *  yi * xi
                b = b + h * yi
sns.set(style='whitegrid')
sns.scatterplot(x='x1',y='x2',hue='y',data=data,)
x=np.linspace(-1,3)
plt.plot(x,-(w.tolist()[0][0]/w.tolist()[0][1])*x+b.tolist()[0][0])
plt.show()

画图结果:
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值