1. Principle
The original form of the single-layer perceptron is:

    y = sign(w1*x1 + w2*x2 + ... + wn*xn + b)

If the bias term b is treated as a special weight w0 attached to a constant input x0 = 1, the single-layer perceptron can be rewritten as:

    y = sign(w0*x0 + w1*x1 + ... + wn*xn)
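As a minimal illustration of folding the bias into the weight vector (the weight values below are made-up examples; the input format matches the training data X used in the sample code later):

import numpy as np

# bias folded in as w[0]; the input vector gets a constant 1 prepended
w = np.array([-4.0, 1.0, 1.0])  # [b, w1, w2] -- assumed example values
x = np.array([1.0, 3.0, 3.0])   # [1, x1, x2] -- same format as the training data below
print(np.sign(np.dot(w, x)))    # w.x = -4 + 3 + 3 = 2, so the output is 1.0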
The perceptron learning rule: each weight is adjusted in proportion to the error between the target output and the actual output,

    wi ← wi + η * (y - ŷ) * xi

where η is the learning rate, y is the target output, and ŷ is the perceptron's current output.
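A one-step sketch of this rule on a single sample (the initial weights, input, label, and learning rate are all assumed example values, not from the post):

import numpy as np

w = np.array([0.0, 0.0, 0.0])   # initial weights [b, w1, w2]
x = np.array([1.0, 1.0, 1.0])   # augmented input whose target label is -1
y = -1
learningRate = 0.1

output = np.sign(np.dot(w, x))  # sign(0) = 0, so the sample is not yet classified as -1
w = w + learningRate * (y - output) * x
print(w)                        # [-0.1 -0.1 -0.1], pushed towards the -1 side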
Learning rate:
(1) If the learning rate is too large, the weight updates tend to be unstable.
(2) If the learning rate is too small, the weights change too slowly and too many iterations are needed.
Convergence conditions (a short code sketch follows this list):
(1) The error falls below a small, predefined value.
(2) The weight change between two consecutive iterations is already very small.
(3) A maximum number of iterations is set; training stops once that number is exceeded.
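The sample code below stops on condition (1), i.e. when every training sample is classified correctly, with a cap on iterations as condition (3). A minimal sketch of condition (2), stopping when the weight change between two iterations is tiny, could look like this (the data match the test script below; the 1e-5 threshold and the 1000-iteration cap are assumed values):

import numpy as np

X = np.array([[1, 3, 3], [1, 4, 3], [1, 1, 1]], dtype=float)
Y = np.array([1, 1, -1], dtype=float)
w = np.zeros(3)
learningRate = 0.1

for i in range(1000):                    # condition (3): hard cap on iterations
    output = np.sign(X.dot(w))
    delta = learningRate * (Y - output).dot(X) / X.shape[0]
    w += delta
    if np.linalg.norm(delta) < 1e-5:     # condition (2): weight change is tiny
        print('stopped after', i + 1, 'iterations')
        break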
2. Sample Code
Perceptron.py
# -*- encoding: utf-8 -*-
import numpy as np


class Perceptron:
    def __init__(self):
        self.W = None
        self.trainTimes = 0

    def train(self, X, Y, n, learningRate):
        '''Train for at most n iterations with the given learning rate.'''
        # initialize W with random values in [-1, 1)
        row, column = X.shape
        self.W = (np.random.random(column) - 0.5) * 2
        # train for at most n iterations
        for i in range(n):
            self.trainTimes += 1
            # current outputs and batch weight update, averaged over all samples
            output = np.sign(np.dot(X, self.W.T))
            gain = learningRate * ((Y - output).dot(X)) / row
            self.W += gain
            # stop as soon as every training sample is classified correctly
            newOutput = np.sign(np.dot(X, self.W.T))
            if (newOutput == Y).all():
                break

    def getW(self):
        return self.W

    def getTrainTimes(self):
        return self.trainTimes

    def predict(self, x):
        return np.sign(np.dot(x, self.W.T))
Test4Perceptron.py
# -*- encoding: utf-8 -*-
from Perceptron import Perceptron
import numpy as np
import matplotlib.pyplot as plt


def test():
    # training data: the first column is the constant 1 for the bias term
    X = np.array([[1, 3, 3], [1, 4, 3], [1, 1, 1]])
    Y = np.array([1, 1, -1])
    learningRate = 0.1
    learningTimes = 100
    # training
    perceptron = Perceptron()
    perceptron.train(X, Y, learningTimes, learningRate)
    W = perceptron.getW()
    trainTimes = perceptron.getTrainTimes()
    print(W)
    print(trainTimes)
    # plot the training data
    X1 = [3, 4]
    Y1 = [3, 3]
    X2 = [1]
    Y2 = [1]
    # decision boundary: W[0] + W[1]*x + W[2]*y = 0, i.e. y = k*x + d
    k = -W[1] / W[2]
    d = -W[0] / W[2]
    x = np.linspace(0, 5)  # evenly spaced x values for the boundary line
    # plot
    plt.figure()
    plt.plot(x, x * k + d, 'r')
    plt.plot(X1, Y1, 'bo')
    plt.plot(X2, Y2, 'yo')
    plt.show()
    # predict
    test = np.array([[1, 2, 3], [1, 6, 2], [1, 3, 3], [1, 7, 5], [1, 5, 7], [1, 9, 2]])
    testResult = perceptron.predict(test)
    print(testResult)
    testX1 = []
    testY1 = []
    testX2 = []
    testY2 = []
    # split the test points by predicted class for plotting
    for i in range(len(testResult)):
        if testResult[i] >= 0:
            testX1.append(test[i][1])
            testY1.append(test[i][2])
        else:
            testX2.append(test[i][1])
            testY2.append(test[i][2])
    plt.figure()
    plt.plot(x, x * k + d, 'r')
    plt.plot(testX1, testY1, 'bo')
    plt.plot(testX2, testY2, 'yo')
    plt.show()


if __name__ == '__main__':
    test()