简介
感知机是一种较为简单的二分类模型,但由简至繁,感知机却是神经网络和支持向量机的基础。感知机旨在学习能够将输入数据划分为+1/-1的线性分离超平面,所以说整体而言感知机是一种线性模型。因为是线性模型,所以感知机的原理并不复杂。
感知机原理
假设输入x表示为任意实例的特征向量,输出y={+1,-1}为实例的类别。感知机定义由输入到输出的映射函数如下:
![1](https://i-blog.csdnimg.cn/blog_migrate/6066666689511b31fd6004acfa637727.jpeg)
其中sign符号函数为:
![2](https://i-blog.csdnimg.cn/blog_migrate/abb7e5a9eb6ec1962336f9db31481b98.png)
w和b为感知机模型参数,也是感知机要学习的东西。w和b构成的线性方程wx+b=0极为线性分离超平面。
![3](https://i-blog.csdnimg.cn/blog_migrate/23f43bc575d25f0b62862246c1fb4933.jpeg)
假设数据是线性可分的,当然有且仅在数据线性可分的情况下,感知机才能奏效。感知机模型简单,但这也是其缺陷之一。所谓线性可分,也即对于任何输入和输出数据都存在某个线性超平面wx+b=0能够将数据集中的正实例点和负实例点完全正确的划分到超平面两侧,这样数据集就是线性可分的。
感知机的训练目标就是找到这个线性可分的超平面。为此,定义感知机模型损失函数如下:
![4](https://i-blog.csdnimg.cn/blog_migrate/db06ae042f7a40c221a90ad7a1318de6.jpeg)
要优化这个损失函数,可采用梯度下降法对参数进行更新以最小化损失函数。计算损失函数关于参数w和b的梯度如下:
![5](https://i-blog.csdnimg.cn/blog_migrate/ec06a458b7b213488157816b01a70e06.jpeg)
由上可知完整的感知机算法包括参数初始化、对每个数据点判断其是否误分,如果误分,则按照梯度下降法更新超平面参数,直至没有误分类点。
![6](https://i-blog.csdnimg.cn/blog_migrate/75f37be481e4027899aa532ac792f236.jpeg)
以上便是感知机算法的基本原理。当然这里说的感知机仅限于单层的感知机模型,仅适用于线性可分的情况。
代码实现
import copy
from matplotlib import pyplot as plt
from matplotlib import animation
training_set = [[(1, 2), 1], [(2, 3), 1], [(3, 1), -1], [(4, 2), -1]] # 训练数据集
w = [0, 0] # 参数初始化
b = 0
history = [] # 用来记录每次更新过后的w,b
def update(item):
"""
随机梯度下降更新参数
:param item: 参数是分类错误的点
:return: nothing 无返回值
"""
global w, b, history # 把w, b, history声明为全局变量
w[0] += 1 * item[1] * item[0][0] # 根据误分类点更新参数,这里学习效率设为1
w[1] += 1 * item[1] * item[0][1]
b += 1 * item[1]
history.append([copy.copy(w), b]) # 将每次更新过后的w,b记录在history数组中
def cal(item):
"""
计算item到超平面的距离,输出yi(w*xi+b)
(我们要根据这个结果来判断一个点是否被分类错了。如果yi(w*xi+b)>0,则分类错了)
:param item:
:return:
"""
res = 0
for i in range(len(item[0])): # 迭代item的每个坐标,对于本文数据则有两个坐标x1和x2
res += item[0][i] * w[i]
res += b
res *= item[1] # 这里是乘以公式中的yi
return res
def check():
"""
检查超平面是否已将样本正确分类
:return: true如果已正确分类则返回True
"""
flag = False
for item in training_set:
if cal(item) <= 0: # 如果有分类错误的
flag = True # 将flag设为True
update(item) # 用误分类点更新参数
if not flag: # 如果没有分类错误的点了
print("最终结果: w: " + str(w) + "b: " + str(b)) # 输出达到正确结果时参数的值
return flag # 如果已正确分类则返回True,否则返回False
if __name__ == "__main__":
for i in range(1000): # 迭代1000遍
if not check(): break # 如果已正确分类,则结束迭代
# 以下代码是将迭代过程可视化
# 首先建立我们想要做成动画的图像figure, 坐标轴axis,和plot element
fig = plt.figure()
ax = plt.axes(xlim=(0, 2), ylim=(-2, 2))
line, = ax.plot([], [], 'g', lw=2) # 画一条线
label = ax.text([], [], '')
def init():
line.set_data([], [])
x, y, x_, y_ = [], [], [], []
for p in training_set:
if p[1] > 0:
x.append(p[0][0]) # 存放yi=1的点的x1坐标
y.append(p[0][1]) # 存放yi=1的点的x2坐标
else:
x_.append(p[0][0]) # 存放yi=-1的点的x1坐标
y_.append(p[0][1]) # 存放yi=-1的点的x2坐标
plt.plot(x, y, 'bo', x_, y_, 'rx') # 在图里yi=1的点用点表示,yi=-1的点用叉表示
plt.axis([-6, 6, -6, 6]) # 横纵坐标上下限
plt.grid(True) # 显示网格
plt.xlabel('x1') # 这里我修改了原文表示
plt.ylabel('x2') # 为了和原理中表达方式一致,横纵坐标应该是x1,x2
plt.title('Perceptron Algorithm (www.hankcs.com)') # 给图一个标题:感知机算法
return line, label
def animate(i):
global history, ax, line, label
w = history[i][0]
b = history[i][1]
if w[1] == 0: return line, label
# 因为图中坐标上下限为-6~6,所以我们在横坐标为-7和7的两个点之间画一条线就够了,这里代码中的xi,yi其实是原理中的x1,x2
x1 = -7
y1 = -(b + w[0] * x1) / w[1]
x2 = 7
y2 = -(b + w[0] * x2) / w[1]
line.set_data([x1, x2], [y1, y2]) # 设置线的两个点
x1 = 0
y1 = -(b + w[0] * x1) / w[1]
label.set_text(history[i])
label.set_position([x1, y1])
return line, label
print("参数w,b更新过程:", history)
anim = animation.FuncAnimation(fig, animate, init_func=init, frames=len(history), interval=1000, repeat=True,
blit=True)
plt.show()