感知机
最近开始机器学习之旅(有我一样先学习深度学习后机器学习的么。。。),所用书籍是李航老师的《统计学习方法》,关于感知机的原理比较简单,但可以其中随机梯度下降的思想很重要,这里考虑用python实现感知机的原始形式。
PS:
网站很多大佬写的代码很漂亮,但没有可视化总感觉有点不够形象,所以简单用matplot写了个训练过程的可视化。
tips:
- 程序涉及numpy和matplot两个库,由matplot提供可视化
- 数据生成阶段提供两种模式:随机区域与随机圆形区域,没有仔细来设计更好的数据簇
- 使用matplot时要记得打开交互模式,不然会一张图显示了程序就停在那里了
- 训练过程学习率(lr)不要设置得太大,不然可能会错过最优点
matplot打开交互模式
plt.ion()
关闭交互
plt.ioff()
下面附上完整的代码:(所有都塞进一个类里面了。。。)
from matplotlib import pyplot as plt
import numpy as np
import random
class Perceptron(object):
"""
实现感知机的原始形式的数据生成与训练以及过程的动态显示
"""
def __init__(self, model, lr=0.1):
"""
:param model: 生成数据时的方法:0-随机散乱;1-随机圆簇
:param lr: 学习率lr,不要太大,过大会错过最优点
"""
self.model = model
self.eta = lr
self.times_out = 2000 # 连续有2000个(当然大于数据总数)分类成功则训练结束
self.truth_w = (2, 3) # 生成数据所用的权重和偏置
self.truth_h = -800
self.p1, self.p2 = self.data_points_gen(self.truth_w, self.truth_h, 300, self.model)
# 初始化训练用的权重与偏置(全零也可以)
self.out_w = np.array([0.1, 0.1], dtype=np.float)
self.out_b = -100
self.total_step = 0
# 生成训练用数据
@staticmethod
def data_points_gen(truth_w, truth_b, data_num, model):
"""
:param truth_w: 用于生成数据的权重 (w1, w2)
:param truth_b: 用于生成数据的偏置 b0
:param data_num: 生成每种颜色points的数量
:param model: 模式选择,随机排布(0)还是圆形区域排布(1)
:return: 两组点信息 numpy array格式
"""
w_np = np.array(truth_w, dtype=np.float).T
b_np = np.array(truth_b, dtype=np.float)
point1_num = 0
point2_num = 0
point1 = []
point2 = []
assert model in [0, 1], "The argument 'model' should have value: 0 or 1"
if model == 0:
while point1_num < data_num or point2_num < data_num:
x1 = random.randint(1, 800)
x2 = random.randint(1, 800)
x_np = np.array([x1, x2], dtype=np.float)
# print(x_np)
if np.matmul(w_np, x_np) + b_np < -80:
if point1_num < data_num:
point1.append([x1, x2])
point1_num += 1
elif np.matmul(w_np, x_np) + b_np >= 80:
if point2_num < data_num:
point2.append([x1, x2])
point2_num += 1
if model == 1:
intersection = -truth_b // (truth_w[0] + truth_w[1])
circle_center1 = np.array([(0 + intersection) // 4 * 3, (0 + intersection) // 4 * 3])
circle_center2 = np.array([(800 + intersection) // 4, (800 + intersection) // 4])
# print(circle_center1, circle_center2)
while point1_num < data_num or point2_num < data_num:
x1 = random.randint(1, 800)
x2 = random.randint(1, 800)
x_np = np.array([x1, x2], dtype=np.float)
# print(x_np)
if np.matmul(w_np, x_np) + b_np < -20 and \
np.sqrt((x_np[0]-circle_center1[0])**2+(x_np[1]-circle_center1[1])**2) < 80:
if point1_num < data_num:
point1.append([x1, x2])
point1_num += 1
elif np.matmul(w_np, x_np) + b_np >= 20 and \
np.sqrt((x_np[0]-circle_center2[0])**2+(x_np[1]-circle_center2[1])**2) < 80:
if point2_num < data_num:
point2.append([x1, x2])
point2_num += 1
point1 = np.array(point1)
point2 = np.array(point2)
return point1, point2
# sign函数:x > 0则return 1,否则return -1
@staticmethod
def sign(x):
return 1.0 if x >= 0 else -1.0
# 计算w*x+b
@staticmethod
def cal_fx(w, b, x):
w = np.array(w)
b = np.array(b)
x = np.array(x).T
return np.matmul(w, x) + b
@staticmethod
def draw_img(p1, p2, out_w, out_b):
plt.clf()
plt.title('Perceptron Demo', fontsize=15)
plt.xlabel('x1', fontsize=13)
plt.ylabel('x2', fontsize=13)
plt.scatter(p1[:, 0], p1[:, 1], c='r', marker='o')
plt.scatter(p2[:, 0], p2[:, 1], c='b', marker='^')
x1 = np.linspace(10, 400, 2)
x2 = -out_w[0] / out_w[1] * x1 - out_b / out_w[1]
# x2 = x2.astype(np.int)
# print(x1, x2)
plt.plot(x1, x2, c='g', linewidth=3)
plt.pause(0.5)
plt.show()
def train(self):
times_counter = 0
plt.ion() # 打开交互模式,使可以动态刷新
while True:
self.total_step += 1
rad_int = random.randint(0, 10000)
label = rad_int % 2
if label == 0:
y = -1
rad_int = rad_int % self.p1.shape[0]
point_x = self.p1[rad_int]
else:
y = 1
rad_int = rad_int % self.p2.shape[0]
point_x = self.p2[rad_int]
# print(point_x)
out = self.cal_fx(self.out_w, self.out_b, point_x)
out_sign = self.sign(out)
if y * out_sign > 0:
times_counter += 1
else:
times_counter = 0
self.out_w += self.eta * y * point_x
self.out_b += self.eta * y
if self.total_step % 100 == 0:
print("Total train step: {}".format(self.total_step))
print('out_w', self.out_w, 'out_b', self.out_b)
self.draw_img(self.p1, self.p2, self.out_w, self.out_b)
if times_counter >= self.times_out:
break
print("At last, the weight and bias after train is:")
print("wight:", self.out_w)
print("bias:", self.out_b)
plt.ioff() # 关闭画图的窗口,即关闭交互模式
plt.show()
if __name__ == "__main__":
ptron = Perceptron(model=1, lr=0.1)
ptron.train()
初始matplot图像:
完成训练最终图像: