感知机纯numpy

感知机算法

这里前面会通过简单的讲解来理解感知机的算法,后面会有纯numpy的实例代码。

1.感知机的特征

感知机是一个线性分类的模型,有以下几点特征:

  1. 二分分类;
  2. 基于误分类进行参数优化;
  3. 属于判别模型;
  4. 优化方法基于梯度下降;

2.模型

2.1.假设有一个输入空间X,输出空间y={+1,-1}。由输入空间到输出空间用
f(x)=sign(w*x+b)来表示预测值。如图:
在这里插入图片描述
2.2 感知机的损失函数为:在这里插入图片描述
这里来解释一下这个损失函数是如何求解出来的,首先先来图解:
在这里插入图片描述

由图可知w和b的优化就是为了找到图中的分离超平面(这里平面就是分离线,如果是多维就是面或其他形式)。
其中图中任意一点x到超平面的距离为:在这里插入图片描述

我们来看一下误分类的点,对于误分类的点来说:
在这里插入图片描述
任一点到超平面S的距离为:
在这里插入图片描述
假设有M个误分点,则所有误分点到超平面S的距离为:
在这里插入图片描述
不考虑在这里插入图片描述
则损失函数为:
在这里插入图片描述

3.感知机的参数优化

既然得出了损失函数,那损失函数我们根据梯度下降的方法来优化,步骤为以下几步:
1.设定循环次数和步长
2.每次循环遍历所有点,但只对误分点进行w和b的优化
3.设置一个损失对比值,用于避免过于微小的损失迭代
4.根据误分点、循环总数、损失对比值等条件判断停止迭代
5.获取最终优化好的w和b

我们来看一下如何更新w和b:
在这里插入图片描述

4.预测

获取了最终优化好的w和b之后在进行预测,利用之前的感知机模型
sign(w*x+b)就可以获取到预测值了。在测试的过程中再与真实值做比较即可。

5.代码:


import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import Perceptron
from sklearn.model_selection import train_test_split
from loguru import logger

import argparse


class 手动感知机():
    """description of class"""
    def __init__(self, learning_rate=0.1,n_epoch=500,loss_tolerance=0.001):

        self.Learning_Rate = learning_rate
        self.num = n_epoch
        self.loss_tolerance=loss_tolerance
        
    def fit(self,X,y):
        m = np.size(y,axis=0)
        n_sample, n_feature = X.shape
        rnd_val = 1 / np.sqrt(n_feature)
        rng = np.random.default_rng()
        # 均匀随机初始化权重参数
        self.w = rng.uniform(-rnd_val, rnd_val, size=n_feature)
        self.b = 0
        costs = self.gradFunction(X,y,self.Learning_Rate,self.num)
        
    def gradFunction(self,X,y,Learning_Rate,num):
         n_sample, n_feature = X.shape
         loss_pred=0
         num_time=0
         while True:
             error =0
             loss =0
             num_now=0
             for i in range(n_sample):
                 y_pred = (X[i].dot(self.w)+self.b)*y[i]
                 loss +=-y[i]*y_pred
                 if y_pred<=0:
                     #只对误分类进行更新
                     self.w =self.w +self.Learning_Rate*y[i]*X[i]
                     self.b =self.b +self.Learning_Rate*y[i]
                     error+=1
             loss_diff = loss -loss_pred
             loss_pred = loss
             num_time+=1

             if num_time>num or error == 0 or abs(loss_diff)<self.loss_tolerance:
                 break;

    def predict(self, x):
        """给定输入样本,预测其类别"""
        y_pred = np.dot(self.w, x) + self.b
        return 1 if y_pred >= 0 else -1
        
    def costFunction(self,X,y,w,b):
        cost = (X.dot(w) + b).dot(y.T) * -1
        return cost

def main():
    parser = argparse.ArgumentParser(description="感知机算法Scratch实现命令行参数")
    parser.add_argument("--nepoch", type=int, default=500, help="训练多少个epoch后终止训练")
    parser.add_argument("--lr", type=float, default=0.0001, help="学习率")
    parser.add_argument("--loss_tolerance", type=float, default=0.001, help="当前损失与上一个epoch损失之差的绝对值小于该值时终止训练")
    args = parser.parse_args()
    X,y = load_iris(return_X_y=True)
    y[:50] = -1
    xtrain,xtest,ytrain,ytest = train_test_split(X[:100],y[:100],train_size=0.8,shuffle=True)
    
    model = 手动感知机(args.lr,args.nepoch,args.loss_tolerance)
    model.fit(xtrain,ytrain)
    
    n_test = xtest.shape[0]
    n_right = 0
    for i in range(n_test):
        y_pred = model.predict(xtest[i])
        if y_pred == ytest[i]:
            n_right +=1
        else:
            logger.info("该样本真实标签为:{},但是Scratch模型预测标签为:{}".format(ytest[i],
            y_pred))
    logger.info("Scratch模型在测试集上的准确率为:{}%".format(n_right * 100 / n_test))

    skmodel = Perceptron(max_iter=args.nepoch)
    skmodel.fit(xtrain, ytrain)
    logger.info("sklearn模型在测试集上准确率为:{}%".format(100 * skmodel.score(xtest, ytest)))

if __name__ == "__main__":
    main()

在这里插入图片描述
预测的结果

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值