python 实现感知机(perceptron)

一、基本原理

感知机(perceptron)是一个二类分类的线性分类模型,其几何意义是寻找一个超平面将点(特征空间)划分为正类和负类。本文以二维平面为例,实现一个简单的感知机模型。

二、实现思路

在二维平面中,感知机的训练过程即寻找一条直线,这条直线可以将平面中线性可分的点分离开,代码实现思路如下:

  1. 生成训练数据:
    为了保证数据是线性可分的,在生成数据前确定两个点(如:(2, 2)、(6, 6)),在这两个点的周围随机生成数据,分别给这两个点周围的数据加上不同的标签 -1 和 +1。
  2. 训练模型(获得直线参数)
    我们不妨设要寻找的直线方程为:w[0] * x + w[1] * y + b,初始化参数 w = [0., 1.] 和 b = 0,即直线的初始方程为 y = 0。
    接着从训练集中取出一个点,将这个点带入到目前训练的直线中,如果求出的值和该点的标签乘积小于等于零,说明直线没有将这个点正确分类,这时更新 w[0],w[1] 和 b 的值。
    更新规则为:w[0] += 学习率 * 标签值 * 点的横坐标,w[1] += 学习率 * 标签值 * 点的纵坐标,b += 学习率 * 标签值。
    遍历训练集的所有点,如果点没有正确分类就按上述更新规则更新 w 和 b 的值,直到所有点都被正确分类为止。根据得到的 x, y 的参数 w[0], w[1] 和 b 的值,计算出直线的斜率和截距。
  3. 将训练数据(点)和直线画出
    根据步骤 2 得到的斜率和截距画出直线,根据训练集的点和点的标签画出点。

三、源代码

"""
@description: perceptron
@author: Zhao Chengcheng
"""

import numpy as np
import matplotlib.pyplot as plt


def get_data(num):
    """
    @description: 随机生成数据
    @param num: 数据条数
    @return data: 点的坐标
    @return label: 每个点的标签,为:-1 或 +1
    """
    data = [] # 存放随机生成的坐标 Xn
    label = [] # 存放数据的标签, -1 或者 +1
    x1 = np.random.normal(2, 0.8, int(num / 2))
    y1 = np.random.normal(2, 0.8, int(num / 2)) # 在点 (2, 2) 周围生成点
    x2 = np.random.normal(6, 0.8, int(num / 2))
    y2 = np.random.normal(6, 0.8, int(num / 2)) # 在点 (6, 6) 周围生成点,保证生成的点是可被划分的
    for i in range(num):
        if i < num / 2:
            data.append([x1[i], y1[i]])
            label.append(-1)
        else:
            data.append([x2[int(i - num / 2)], y2[int(i - num / 2)]])
            label.append(1)
    return data, label


def perceptron(data, label, eta):
    """
    训练感知机
    @param data: 包含坐标的数据
    @param label: 数据的标签 -1 或者 +1
    @param eta: 学习率
    @return slope: 斜率
    @return intercept: 截距
    """
    w = [0., 1.0] # 直线 x 和 y 的系数
    b = 0.
    separated = False # 标记是否已将点完全分离
    while not separated:
        separated = True
        for i in range(len(data)):
            if label[i] * (w[0] * data[i][0] + w[1] * data[i][1] + b) <= 0:
                separated = False # 没有完全分离
                w[0] += eta * label[i] * data[i][0] # 更新 w 的值
                w[1] += eta * label[i] * data[i][1]
                b += eta * label[i] # 更新 b 的值
    slope = -w[0] / w[1]    # 斜率
    intercept = -b / w[1]   # 截距
    return slope, intercept


def plot(data, label, slope, intercept):
    """
    @description: 画出点和超平面(直线)
    @param data: 点的坐标
    @param label: 点的标签
    @param slope: 直线的斜率
    @param intercept: 直线的纵截距
    """
    plt.rcParams['font.sans-serif'] = ['SimHei'] # 设置字体
    plt.rcParams['axes.unicode_minus'] = False
    plt.xlabel('X')
    plt.ylabel('Y')
    area = np.pi * 2 ** 2 # 点的面积

    data_mat = np.array(data)
    X = data_mat[:, 0]
    Y = data_mat[:, 1]
    for i in range(len(label)):
        if label[i] > 0:
            plt.scatter(X[i].tolist(), Y[i].tolist(), s=area, color='red')  # 画点
        else:
            plt.scatter(X[i].tolist(), Y[i].tolist(), s=area, color='green')
    # 根据斜率和截距画出直线
    axes = plt.gca()
    x_vals = np.array(axes.get_xlim())
    y_vals = intercept + slope * x_vals
    plt.plot(x_vals, y_vals)
    plt.show()


data, label = get_data(100) # 生成数据和标签
slope, intercept = perceptron(data, label, 1) # 训练模型,得到直线的斜率和截距
plot(data, label, slope, intercept) # 画出点和直线

四、代码运行结果

感知机运行结果1
感知机运行结果2
源码地址:感知机原始形式实现

  • 4
    点赞
  • 45
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值