# The linearly separable case (符合线性可分的情况下).
import numpy as np
import matplotlib.pyplot as plt
# Two clouds of 50 two-dimensional points each, drawn from normal
# distributions with standard deviation 2 and means 6 and -6 respectively.
label_a = np.random.normal(loc=6, scale=2, size=(50, 2))
label_b = np.random.normal(loc=-6, scale=2, size=(50, 2))

# First coordinate of every point in each cloud.
label_a_x = label_a[:, 0]
label_b_x = label_b[:, 0]

# Scatter both clouds (column 0 on the horizontal axis, column 1 on the
# vertical axis); each call takes the next colour of the default cycle.
plt.scatter(label_a[:, 0], label_a[:, 1])
plt.scatter(label_b[:, 0], label_b[:, 1])
def f(x, w, b):
    """Evaluate the linear function ``w * x + b``.

    Args:
        x: Input value(s). Any operand supporting ``*`` and ``+`` works,
            e.g. a number or a NumPy array (evaluated element-wise).
        w: Slope applied to ``x``.
        b: Intercept added after scaling.

    Returns:
        The result of ``w * x + b``; its type and shape follow from the
        arithmetic of the operands.
    """
    return w * x + b
# Random search for a separating linear function f(x) = k * x + b over the
# first coordinate of the samples. A candidate is kept only when it maps every
# label_a sample to <= -1 and every label_b sample to >= 1.
k_and_b = []
for _ in range(100):
    # Slope and intercept drawn uniformly from [-5, 5).
    k, b = np.random.uniform(-5, 5, size=2)
    print(k, b)
    if np.max(f(label_a_x, k, b)) <= -1 and np.min(f(label_b_x, k, b)) >= 1:
        print(k, b)
        k_and_b.append((k, b))

# x positions of all samples; identical for every line drawn below, so it is
# built once instead of once per accepted candidate.
all_x = np.concatenate((label_a_x, label_b_x))
for k, b in k_and_b:
    plt.plot(all_x, f(all_x, k, b))
print(k_and_b)

if k_and_b:
    # Highlight the candidate with the smallest |k|. The gap between the
    # f(x) = -1 and f(x) = 1 levels is 2 / |k| wide, so the smallest magnitude
    # gives the widest margin. Comparing the signed k instead would pick the
    # steepest line whenever the accepted slopes are negative.
    w, b = min(k_and_b, key=lambda k_b: abs(k_b[0]))
    plt.plot(all_x, f(all_x, w, b), 'r-o')
else:
    # min() raises ValueError on an empty list; report the situation and
    # still show the scatter plot.
    print("no candidate satisfied the margin constraints; nothing to highlight")
plt.show()