# -*- coding: utf-8 -*-
"""
Created on Thu Oct 15 13:58:06 2015
@author: Think
"""
#感知器算法
import mkdata as mk
import numpy as np
import matplotlib.pyplot as plt
N = 100 #生成测试点的数目
def check(item, y, w, b):
ans = w[0]*item[0] + w[1]*item[1] + b
ans *= y
if ans > 0:
return True
else:
return False
def perceptron(X,y):
iterNums = 1000
m,n = X.shape
w = np.zeros(m)
b = 0
a = 0.01
for i in range(iterNums):
for j in range(n):
if not check(X[:,j], y[0][j], w, b):
w = w + a * y[0][j]*X[:,j]
b += a*y[0][j]
return (w, b)
if __name__ == '__main__':
(X,y,w) = mk.mk_data(N) #是线性可分
#(X,y,w) = mk.mk_data(N,True) # 不是线性可分
plt.scatter(X[0,y[0]==1], X[1,y[0]==1], color='red')
plt.scatter(X[0,y[0]==-1], X[1,y[0]==-1], color='green')
w, b = perceptron(X, y)
x = np.arange(-2,2,0.1)
x2 = (-b-w[0]*x)/w[1]
plt.plot(x,x2)
plt.show()
mk_data函数的链接为:python生成测试数据点
如果是线性可分数据点,结果如下:
如果数据点不是线性可分的,效果如下: