读取数据:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Load the data.
# The UCI "iris.data" file ships without a header row, so header=None is
# required; otherwise pandas consumes the first sample as column names and
# every positional row slice below is shifted by one.
df = pd.read_csv("./iris.data", header=None)

# Column 4 holds the species name. Encode "Iris-setosa" as +1 and every
# other species found in the first 100 rows as -1.
y = df.iloc[0:100, 4].values
y = np.where(y == "Iris-setosa", 1, -1)

# Keep two features only: column 0 and column 2 (plotted below as sepal
# length and petal length).
x = df.iloc[0:100, [0, 2]].values

# Scatter plot of the two classes.
plt.scatter(x[y == 1, 0], x[y == 1, 1], color='red', marker="o", label="Iris-setosa")
# NOTE(review): in the standard UCI file ordering rows 50-99 are
# Iris-versicolor, hence the label; the original labelled this series
# "setosa", which contradicted the y == -1 (non-setosa) mask.
plt.scatter(x[y == -1, 0], x[y == -1, 1], color='blue', marker="o", label="Iris-versicolor")
plt.xlabel("sepal length [cm]")
plt.ylabel("petal length [cm]")
plt.legend(loc="upper left")
plt.show()
采用自定义的感知机进行训练:
class perception:
def __init__(self, X, Y, lr = 0.01, Iter = 50):
self.lr = lr
self.Iter = Iter
self.X = X
self.Y = Y
np.random.seed(10)
self.w = np.random.normal(loc=0.0, scale=0.01, size=X.shape[1]+1)
def prediction(self, x):