import numpy as np
import matplotlib.pyplot as plot
import random
import time
def randData(w, b, num):
    """Generate ``num`` random points labelled by the line y = w * x + b.

    Each coordinate is drawn uniformly from [-10, 10). A point (px, py) is
    labelled +1 when ``w * px + b > py`` and -1 otherwise.

    Args:
        w: slope of the labelling line.
        b: intercept of the labelling line.
        num: number of points to generate.

    Returns:
        (x_list, y_list): ``x_list`` is a list of ``[px, py]`` pairs, each
        coordinate a 1-element numpy array; ``y_list`` holds the matching
        +1/-1 labels.
    """
    x_list = []
    y_list = []
    for _ in range(num):
        px = np.random.random_sample(1) * 20 - 10
        py = np.random.random_sample(1) * 20 - 10
        # Label with the parameters, not the globals w1/b1 (those only exist
        # when the file is run as a script).
        label = 1 if w * px + b > py else -1
        x_list.append([px, py])
        y_list.append(label)
    return x_list, y_list
def show():
    """Plot the true separating line and the labelled samples.

    Reads module-level names assigned in the ``__main__`` block: ``x`` (the
    x-coordinates the line is drawn over), ``w1``/``b1`` (the true line) and
    ``sample_x``/``sample_y`` (the points and their +1/-1 labels). Points
    with a positive label are drawn red, all others green.
    """
    # The original xlim(xmax=-10, xmin=10) set left=10, right=-10, which
    # flipped both axes; use the normal increasing [-10, 10] range.
    plot.xlim(-10, 10)
    plot.ylim(-10, 10)
    plot.plot(x, w1 * x + b1, "r", linewidth=3)
    # One scatter call per class instead of one per point.
    points = np.asarray(sample_x, dtype=float).reshape(len(sample_x), 2)
    positive = np.asarray(sample_y) > 0
    plot.scatter(points[positive, 0], points[positive, 1], c='r', alpha=0.3)
    plot.scatter(points[~positive, 0], points[~positive, 1], c='g', alpha=0.3)
def train_one(x, y, epochs, lr, w, b, on_update=None):
    """Fit the line y = w * x + b to labelled points with the perceptron rule.

    Each epoch scans all samples for misclassified points, picks one of them
    at random, and moves ``w``/``b`` towards classifying it correctly.
    Training stops early once an epoch finds no misclassified point.

    Args:
        x: sequence of ``[px, py]`` sample coordinates. Each coordinate may be
            a scalar or a 1-element array, as produced by ``randData``.
        y: matching sequence of +1/-1 labels; +1 means ``w * px + b`` should
            exceed ``py`` (the convention ``randData`` uses).
        epochs: maximum number of epochs (at most one update per epoch).
        lr: learning rate.
        w, b: initial slope and intercept; they are not modified in place.
        on_update: callable ``(w, b, line_x)`` invoked after every update.
            Defaults to ``reshow``, which redraws the fitted line.

    Returns:
        The final ``(w, b)``.
    """
    if on_update is None:
        on_update = reshow
    line_x = np.linspace(-10, 10, 1000)
    points = np.asarray(x, dtype=float).reshape(len(x), 2)
    labels = np.asarray(y, dtype=float)
    for epoch in range(epochs):
        print("epoch:%d" % (epoch + 1))
        # A point is misclassified when the sign of (w * px + b - py)
        # disagrees with its label; a margin of exactly 0 counts as correct.
        margins = (points[:, 0] * w + b - points[:, 1]) * labels
        error_arr = np.flatnonzero(margins < 0).tolist()
        if not error_arr:
            break  # converged: every sample is on the correct side
        index = random.choice(error_arr)
        # Rebind instead of `+=` so the caller's numpy arrays are not mutated.
        w = w + lr * labels[index] * points[index, 0]
        b = b + lr * labels[index]
        on_update(w, b, line_x)
    return w, b
def reshow(w, b, line_x):
    """Redraw the module-level ``line`` artist as y = w * x + b over ``line_x``.

    Prints "update", then redraws the figure and pauses briefly so an
    interactive window refreshes between training steps.
    """
    fitted = w * line_x + b
    line.set_data(line_x, fitted)
    print("update")
    plot.draw()
    plot.pause(0.2)
if __name__ == "__main__":
    # Initial random guess for the model being trained: y = w * x + b.
    w = np.random.random_sample(1) * 10 - 5
    b = np.random.random_sample(1) * 10 - 5
    # Hidden "true" separating line y = w1 * x + b1 used to label the samples.
    w1 = np.random.random_sample(1) * 10 - 5
    b1 = np.random.random_sample(1) * 10 - 5
    # NOTE: show() reads the module-level names x, w1, b1, sample_x and
    # sample_y, and reshow() reads `line`, so these must remain globals.
    x = np.linspace(-10, 10, 1000)
    line_fit = w * x + b
    # Label the samples with the true line, not with the initial guess.
    sample_x, sample_y = randData(w1, b1, 500)
    plot.ion()
    show()
    plot.show()
    # These are the values actually used for training (previously declared
    # as 10 / 0.01 but ignored in favour of the literals 30 / 0.05).
    epochs = 30
    learn_rate = 0.05
    line, = plot.plot(x, line_fit, "g", linewidth=3)
    print(w)
    w, b = train_one(sample_x, sample_y, epochs, learn_rate, w, b)
    # Leave interactive mode so the final plot stays open until closed.
    plot.ioff()
    plot.show()
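# python 实现感知机 (Python implementation of the perceptron)
# 最新推荐文章于 2024-07-18 22:17:35 发布 (latest recommended article published 2024-07-18 22:17:35)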