感知机的理论参考：http://blog.csdn.net/cymy001/article/details/77992416
from IPython.display import Image
%matplotlib inline
# Added version check for recent scikit-learn 0.18 checks
from distutils.version import LooseVersion as Version
from sklearn import __version__ as sklearn_version
from sklearn import datasets
import numpy as np
# Choose the module that provides train_test_split for the installed
# scikit-learn: model_selection for 0.18 and newer, cross_validation before.
if Version(sklearn_version) >= '0.18':
    from sklearn.model_selection import train_test_split
else:
    from sklearn.cross_validation import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import Perceptron

# Iris data set, see
# http://scikit-learn.org/stable/auto_examples/datasets/plot_iris_dataset.html
iris = datasets.load_iris()
X = iris.data[:, [2, 3]]  # keep only the feature columns at index 2 and 3
y = iris.target  # class labels (the species column)

# Split into train and test parts with 30% held out for testing;
# random_state=0 is passed so the split is the same on every run.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=0, test_size=0.3)

# The scaler is fitted on the training split only and stores the fitted
# parameters; the same transformation is then applied to both splits.
sc = StandardScaler()
X_train_std = sc.fit(X_train).transform(X_train)
X_test_std = sc.transform(X_test)

# Perceptron reference:
# http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Perceptron.html#sklearn.linear_model.Perceptron
# An earlier variant of this script passed n_iter=40, eta0=0.1, random_state=0;
# the library defaults are used here. The model is linear: y = w.x + b.
ppn = Perceptron()
ppn.fit(X_train_std, y_train)
# Verify how the perceptron predicts: recompute the per-class scores w.x + b by hand and compare with ppn.predict
def prelabmax(X_test_std):
    """Return the largest decision score of every sample.

    For each row ``x`` of ``X_test_std`` the scores ``coef_ . x + intercept_``
    of the module-level fitted model ``ppn`` are computed, and the maximum
    of those scores is kept.

    Parameters
    ----------
    X_test_std : array indexed as ``X_test_std[i, :]``
        Samples to score, one per row.

    Returns
    -------
    list
        One maximum score per row, in row order.
    """
    # Read the model parameters once instead of on every iteration.
    weights, biases = ppn.coef_, ppn.intercept_
    n_samples = X_test_std.shape[0]
    return [
        max(np.dot(weights, X_test_std[row, :].T) + biases)
        for row in range(n_samples)
    ]
# Display the maximum decision score of each test sample.
prelabmax(X_test_std)
def prelabindex(X_test_std, pym):
    """Return, for every sample, the index of the score that equals its maximum.

    For each row ``x`` of ``X_test_std`` the scores ``coef_ . x + intercept_``
    of the module-level fitted model ``ppn`` are recomputed and compared with
    the corresponding entry of ``pym``.

    Exactly one index is produced per sample, so the result always lines up
    with the rows of ``X_test_std``:

    * the number of scores is taken from the model instead of being fixed
      at three;
    * when several scores tie with ``pym[i]`` only the first index is kept;
    * when no score equals ``pym[i]`` exactly, the index of the largest
      score is used instead.

    Parameters
    ----------
    X_test_std : array indexed as ``X_test_std[i, :]``
        Samples to score, one per row.
    pym : sequence
        Maximum score of each sample, e.g. the output of ``prelabmax``.

    Returns
    -------
    numpy.ndarray
        One index per row of ``X_test_std``.
    """
    weights, biases = ppn.coef_, ppn.intercept_
    index = []
    for i in range(X_test_std.shape[0]):
        py = np.dot(weights, X_test_std[i, :].T) + biases
        # Positions whose score is exactly the recorded maximum.
        matches = np.flatnonzero(py == pym[i])
        if matches.size:
            best = matches[0]
        else:
            # pym[i] was not produced from these exact scores; fall back to
            # the position of the largest score.
            best = np.argmax(py)
        # Plain int keeps the resulting array dtype the same as before.
        index.append(int(best))
    return np.array(index)
# Maximum decision score of every test sample.
pym=prelabmax(X_test_std)
# Index of the score that attains that maximum, per test sample.
prelabindex(X_test_std,pym)
# Element-wise comparison of the hand-computed indices with the model's own
# predictions; the recorded output below shows True for every test sample.
prelabindex(X_test_std,pym)==ppn.predict(X_test_std)
#Output:array([ True, True, True, True, True, True, True, True, True,
# True, True, True, True, True, True, True, True, True,
# True, True, True, True, True, True, True, True, True,
# True, True, True, True, True, True, True, True, True,
# True, True, True, True, True, True, True, True, True], dtype=bool)
即：对每个样本，分别计算各类别的 y = w·x + b，取值最大的那一类即为该样本的预测类别。