# 运用感知机模型实现对鸢尾花分类
# 运用感知机对鸢尾花进行分类
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
# Use the SimHei font so Chinese text in plot labels renders, and draw minus
# signs as ASCII hyphens (the Unicode minus glyph is often missing from CJK
# fonts such as SimHei).
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Iris data set from the UCI repository, fetched over HTTPS instead of
# plaintext HTTP. The file has no header row, so header=None makes pandas
# label the columns with integers (0, 1, ...).
df = pd.read_csv(
    'https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data',
    header=None,
)
class Perceptron():
"""自定义感知机算法"""
def __init__(self,learning_rate=0.01,num_iter=50,random_state=1):
self.learning_rate=learning_rate
self.num_iter=num_iter
self.random_state=random_state
def fit(self,x,y):
rgen=np.random.RandomState(self.random_state)
self.w=rgen.normal(loc=0.0,scale=0.01,size=1+x.shape[1])
self.errors=[]
for _ in range(self.num_iter):
errors=0
for x_i,target in zip(x,y):
update=self.learning_rate*(target-self.predict(x_i))
self.w[1:]+=update*x_i
self.w[0]+=update
errors+=int(update!=0.0)
self.errors.append(errors)
return self
def predict_input(self,x):
return np.dot(x,self.w[1:])+self.w[0]
def predict(self,x):
return np.where(self.predict_input(x)>=0.0<