# 1. 数据 (Data)
# 2.
import matplotlib.pyplot as plt
import numpy as np
# Load the raw samples; ndmin=2 keeps a 2-D (rows, features) matrix even if
# the file has a single row or a single column.
x = np.loadtxt('data.txt', ndmin=2)
# Bias column with one entry per sample (previously hard-coded to 40 rows,
# which broke on any data file of a different length).
b = np.ones(x.shape[0])
# Augment: prepend the bias column so each row becomes [1, x_1, x_2, ...].
y = np.insert(x, 0, b, 1)
class BatchPerception:
    """Batch perceptron that separates two classes of augmented samples.

    The augmented sample matrix ``y`` holds the classes in consecutive row
    groups of ``samples_per_class`` rows each: class ``k`` (1-based) occupies
    rows ``(k - 1) * samples_per_class`` up to ``k * samples_per_class - 1``.
    Only the first three columns of ``y`` are used, so the weight vector ``a``
    has length 3.
    """

    def __init__(self, w1, w2, y, *, lr=1, samples_per_class=10):
        """Set up a classifier for the class pair ``w1`` / ``w2``.

        Args:
            w1: 1-based index of the class whose samples are kept as-is.
            w2: 1-based index of the class whose samples are negated during
                preprocessing.
            y: augmented sample matrix, rows grouped by class as described
                on the class docstring.
            lr: learning rate multiplying each batch update.
            samples_per_class: number of consecutive rows per class in ``y``.
        """
        self.w1 = w1
        self.w2 = w2
        self.a = np.zeros(3)  # weight vector, starts at the origin
        self.count = 0  # number of weight updates performed so far
        self.lr = lr
        self.y = y
        self.samples_per_class = samples_per_class
        self.converged = False  # set by train() once no sample is misclassified

    def preprocess(self):
        """Return the normalized training matrix for the two selected classes.

        Rows of class ``w1`` are kept unchanged and rows of class ``w2`` are
        negated, so a separating weight vector ``a`` satisfies ``a . y > 0``
        for every returned row. Only columns 0-2 are taken. ``self.y`` is not
        modified; a new array is returned.

        Raises:
            ValueError: if ``w1`` or ``w2`` does not name a complete row group
                in ``y``.
        """
        n = self.samples_per_class
        n_classes = len(self.y) // n
        for w in (self.w1, self.w2):
            # An out-of-range index would otherwise yield an empty or negative
            # slice and silently train on the wrong samples.
            if not 1 <= w <= n_classes:
                raise ValueError(f"class index {w} out of range 1..{n_classes}")
        y_w1 = self.y[(self.w1 - 1) * n:self.w1 * n, 0:3]
        # Normalization: negate the second class so one inequality covers both.
        y_w2 = -self.y[(self.w2 - 1) * n:self.w2 * n, 0:3]
        return np.concatenate((y_w1, y_w2), axis=0)

    def train(self, max_iter=1000):
        """Run batch perceptron updates until every sample is classified.

        Each iteration gathers all samples with ``a . y <= 0`` and adds ``lr``
        times their sum to ``a``. When an iteration finds no such sample, the
        class pair, final weights and update count are printed and training
        stops.

        Args:
            max_iter: maximum number of iterations (convergence checks).

        Sets ``self.converged`` to True on convergence. It stays False if
        ``max_iter`` iterations pass with samples still misclassified.
        """
        y_w = self.preprocess()
        self.converged = False
        for _ in range(max_iter):
            # Misclassified: not strictly on the positive side of the boundary.
            misclassified = y_w[y_w @ self.a <= 0]
            if len(misclassified) == 0:
                print(self.w1, '和', self.w2, self.a, self.count)
                self.converged = True
                break
            self.a = self.a + self.lr * misclassified.sum(axis=0)
            self.count += 1
def visualization(self