感知机是一种较为简单的二分类模型,但是感知机却是神经网络和支持向量机的基础。其本质是学习能够将输入的数据划分为-1和1两类的线性分离超平面,所以说感知机是一种线性模型。
模型原理:
若输入x表示为任意实例的特征向量,输出y={+1,-1}为该实例的类别。其输入输出可以用以下函数表示:
sign表示的意思如下所示:
上面的函数中,w,b为模型的参数,也是感知机要学习的东西,w和b构成的线性方程w*x+b为线性分离超平面。
只有当所给数据为线性可分的情况下,感知机才能奏效(所谓线性可分就是对任何输入和输出数据都存在有一个线性超平面能够将数据集中的正、负实例划分到超平面两边)。感知机的训练目标就是要找到这个超平面,在训练过程中优化的模型损失函数如下所示:
(函数解释:(yi(wxi+b)该函数当wxi+b大于零且得到的分类结果为1则说明该数据分类正确两个值同号相乘得正前面再加一个负号变负,所以当所有的yi和(w*xi+b)都同号时得到的损失函数值会达到最小,得到的超平面的分类效果好。)
优化损失函数:
梯度下降法:(关于w和b的梯度如下)
所以由上述可知,感知机算法的步骤包括:参数初始化、对每个数据点判断是否误分,如果误分则按照梯度下降法更新超平面函数的参数,直至没有误分点。
代码实现如下:
import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
#导入iris数据
iris = load_iris()
df = pd.DataFrame(iris.data, columns = iris.feature_names)
df[‘label’] = iris.target
df.columns = [‘sepal length’, ‘sepal width’, ‘petal length’, ‘petal width’, ‘label’]
#绘制散点图
plt.scatter(df[:50][‘sepal length’], df[:50][‘sepal width’], c=‘red’, label=‘0’)
plt.scatter(df[50:100][‘sepal length’], df[50:100][‘sepal width’], c=‘green’, label=‘1’)
plt.xlabel(‘sepal length’)
plt.ylabel(‘sepal width’)
plt.legend()
#plt.show()
data = np.array(df.iloc[:100, [0, 1, -1]])
x, y = data[:, :-1], data[:, -1]
y = np.array([1 if i ==1 else -1 for i in y])
def initilize_with_zeros(dim):
w = np.zeros(dim)
b = 0.0
return w, b
#定义sign符号函数
def sign(x, w, b):
return np.dot(x, w)+b
#定义感知机训练函数
def train(x_train,y_train, learning_rate):
w, b = initilize_with_zeros(x_train.shape[1])
is_wrong = False
while not is_wrong:
wrong_count = 0
for i in range(len(x_train)):
x = x_train[i]
y = y_train[i]
if y * sign(x, w, b)<=0:
w = w + learning_rate *np.dot(y, x)
b = b + learning_rate*y
wrong_count +=1
if wrong_count == 0:
is_wrong = True
print('There is no missclassification!')
params = {'w': w,
'b': b
}
return params
params = train(x,y,0.01)
x_points = np.linspace(4, 7, 10)
y_hat = -(params[‘w’][0]*x_points + params[‘b’])/params[‘w’][1]
plt.figure()
plt.plot(x_points, y_hat)
plt.scatter(data[:50, 0], data[:50, 1], color=‘red’, label=‘0’)
plt.scatter(data[50:100, 0], data[50:100, 1], color=‘green’, label=‘1’)
plt.xlabel(‘sepal length’)
plt.ylabel(‘sepal width’)
plt.legend()
plt.show()