感知器的增量学习过程的逐步图形化显示代码

感知器的学习是非常基本的机器学习算法。网上的算法都是给定数据,画出训练后的结果。

本博客针对感知器的增量式学习过程,对每个样本的学习前后的线性判决线进行了显示。具体情况为:

1.本程序的作用是创建10个训练集,每次训练一个样本,修改线性分界线。
2.鼠标点击一次,训练一次。图形中,+ - 代表线性分类的两类区间。大的点代表下一个要训练的样本,可以推测分界线的修改方向。详见下图。
3.单步学习Single_Step_Perceptron_training函数来源于网络,找不到原网页了...

附代码:

import matplotlib.pyplot as plt
import numpy as np

#@Author lzs_jeff  2020.5.14
#本程序的作用是创建10个训练集,每次训练一个样本,修改线性分界线。
#鼠标点击一次,训练一次。图形中,+ - 代表线性分类的两类区间。大的点代表下一个要训练的样本,可以推测分界线的修改方向。
#单次学习算法来源于网络,找不到原网页了...

x1 = np.array([3, 3], ndmin=2).T
x2 = np.array([1, 1], ndmin=2).T
x3 = np.array([1.5, 1.5], ndmin=2).T
x4 = np.array([4, 3], ndmin=2).T
x5 = np.array([2, 1.5], ndmin=2).T
x6 = np.array([2, 1.8], ndmin=2).T
x7 = np.array([5, 3], ndmin=2).T
x8 = np.array([4, 4], ndmin=2).T
x9 = np.array([0, 1.5], ndmin=2).T
x10 = np.array([4, 5], ndmin=2).T

w = np.random.rand(2).reshape(1, 2)
b = np.random.rand(1)
print(w, b)
lr = 0.6  # 学习率
input =   [x1, x2, x3, x4, x5, x6, x7, x8, x9, x10]
targets = [1, -1, -1,  1,  -1, -1, 1,  1,  -1, 1]
error_num = 0
n = 10  #当前数据为10个,可以自行扩充
current_index = 0

def Single_Step_Perceptron_training(X, w, b, lr, label):
    global error_num
    if label * (np.dot(w, X) + b) <= 0:
        # 利用随机梯度下降算法,更新w,b
        w += lr * label * X.T
        b += lr * label
        error_num = error_num +1
    return w,b

def onclick(event):
    global   current_index, w, b, error_num
    if(current_index < n):
        plt.cla()
        X = input[current_index]
        label = targets[current_index]
        w, b = Single_Step_Perceptron_training(X, w, b, lr, label)

        print(current_index, n)
        i=0
        while i <= current_index:
            if(targets[i] == 1):
                plt.scatter(input[i][0, 0], input[i][1, 0], s=30.0, c='b')
            else:
                plt.scatter(input[i][0, 0], input[i][1, 0], s=30.0, c='r')
            i = i + 1
        print(current_index)
        x = np.array([-5.0, -4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
        y = -(w[0][0] / w[0][1]) * x - b / w[0][1]
        plt.plot(x, y, 'g')

        current_index = current_index + 1
        if current_index == n:
            current_index =0 #重复回到第1个数据学习

        if current_index < n:
            if targets[current_index] == 1:
                plt.scatter(input[current_index][0, 0], input[current_index][1, 0], s=100.0, c='b')
            else:
                plt.scatter(input[current_index][0, 0], input[current_index][1, 0], s=100.0, c='r')

        plt.title('绘制感知器学习过程-错误次数:' + str(error_num) + '学习率:' + str(lr))
        plt.ylim(-5, 5)
        plt.xlim(-5, 5)
        labelxy = np.dot(w, np.array([-4, -4]))
        if (labelxy > 0):
            plt.text(-4, -4, '+', fontsize=30)
        else:
            plt.text(-4, -4, '-', fontsize=30)
        labelxy = np.dot(w, np.array([4, 4]))
        if (labelxy > 0):
            plt.text(4, 4, '+', fontsize=30)
        else:
            plt.text(4, 4, '-', fontsize=30)
        plt.show()


if __name__ == '__main__':
    plt.rcParams['font.sans-serif'] = ['SimHei']
    fig = plt.figure()

    #  下面进行事件连接操作
    fig.canvas.mpl_connect('button_press_event', onclick)
    # 其实mpl_connect只能接受两个参数,一个是你想触发的事件名称,一个就是回调函数

    #下面绘制第一个待预测点和初始分割线
    if targets[current_index] == 1:
        plt.scatter(input[current_index][0, 0], input[current_index][1, 0], s=100.0, c='b')
    else:
        plt.scatter(input[current_index][0, 0], input[current_index][1, 0], s=100.0, c='r')

    #x = np.random.rand(50) * 5
    x = np.array([-5.0, -4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
    y = -(w[0][0] / w[0][1]) * x - b / w[0][1]
    print(x)
    print(y)
    plt.plot(x, y, 'g')
    plt.title('绘制感知器学习过程-错误次数:' + str(error_num) + '学习率:' + str(lr))
    plt.ylim(-5, 5)
    plt.xlim(-5, 5)
    plt.xlabel('x')
    plt.ylabel('y')
    labelxy =  np.dot(w, np.array([-4, -4]))
    if(labelxy > 0):
        plt.text(-4, -4, '+', fontsize=30)
    else:
        plt.text(-4, -4, '-', fontsize=30)
    labelxy = np.dot(w, np.array([4, 4]))
    if (labelxy > 0):
        plt.text(4, 4, '+', fontsize=30)
    else:
        plt.text(4, 4, '-', fontsize=30)

    plt.show()

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值