机器学习——感知机学习

8 篇文章 0 订阅
7 篇文章 0 订阅
这篇博客详细介绍了如何实现统计学习方法中的感知机学习算法。通过简单的模型阐述,配合Python代码展示,包括二维和三维图形的绘制,帮助读者深入理解感知机的工作原理。
摘要由CSDN通过智能技术生成

本篇博客是实现《统计学习方法》中的第二章所讲述的感知机学习 ,这是一个很简单的模型,下面给出算法,


下面给出python代码,包括画出图像

#!/usr/bin/env python
# encoding: utf-8

import matplotlib.pyplot as plt
import numpy as np

x_list = [[3,3],[4,3],[1,2],[1,1],[2,1]]#数据
y_label = [1,1,1,-1,-1]#标签,1为正,-1为负

def perceptron(x_list,y_label,w,b,step):#感知机算法,给出w和b初始值,step为步长
    i = 1
    temp = 1
    length = len(y_label)
    while(bool(i)):#知道没有误分点为止
        index = i%(length+1) - 1
        x = np.array(x_list[index]).T#步骤1,选取数据,并矩阵化
        w = np.array(w).T
        flag = y_label[index]*(np.inner(x,w)+b)#判断是否为误分点
        if flag<= 0 :#是误分点,更新w 和 b
            w = w + step* y_label[index]*x
            b = b + step* y_label[index]
            temp = 0
        i += 1
        if index == length-1 and temp == 1:#如果没有误分点,跳出循环
            return w,b
            break
        temp = 1

def drawline(plt,w,b):#画直线
   x = np.arange(0.0,5.0,0.01)
   y = (-b-w[0]*x)/w[1]
   plt.plot(x,y)

plt.axis([0,5,0,5])#设置x和y轴区间
for x in x_list:#画数据点
    plt.plot(x[0],x[1],'yo-')
plt.title('perceptron')
plt.xlabel('x axis')
plt.ylabel('y axis')
w,b = perceptron(x_list,y_label,[0,0],0,1)
drawline(plt,w,b)
plt.show()

算法很简单,同时算法中给出了相关的注释,应该不难理解,下面给出结果图


下面给出3D的版本


<span style="font-size:12px;">#!/usr/bin/env python
# encoding: utf-8

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D

x_list = [[3,3,2],[4,3,1],[1,2,3],[1,1,2],[2,1,2]]
y_label = [1,1,1,-1,-1]

def perceptron(x_list,y_label,w,b,step):
    i = 1
    temp = 1
    length = len(y_label)
    while(bool(i)):
        index = i%(length+1) - 1
        x = np.array(x_list[index]).T
        w = np.array(w).T
        flag = y_label[index]*(np.inner(x,w)+b)
        if flag<= 0 :
            w = w + step* y_label[index]*x
            b = b + step* y_label[index]
            temp = 0
            print w,b
        i += 1
        if index == length-1 and temp == 1:
            print w,b
            return w,b
            break
        temp = 1

def drawline(plt,w,b):
   x = np.arange(0.0,3.0,0.1)
   y = np.arange(0.0,3.0,0.1)
   x,y = np.meshgrid(x,y)
   z = (-b-w[0]*x-w[1]*x)/w[2]
   ax.plot_surface(x,y,z, rstride=1, cstride=1, cmap=cm.coolwarm, linewidth=0, antialiased=False)

fig = plt.figure()
ax = fig.add_subplot(1,1,1,projection='3d')
for x in x_list:
    ax.scatter(x[0],x[1],x[2],c='b')
w,b = perceptron(x_list,y_label,[0,0,0],0,1)
drawline(plt,w,b)
plt.show()</span>
结果如下


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值