BP Algorithm Derivation:
The derivation of the BP (backpropagation) neural network algorithm is not given here; I followed Han Liqun's Artificial Neural Networks: Theory, Design and Applications, pp. 47-50. (Leave a comment if you need the e-book.)
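For reference, these are the update rules that the code below implements (the standard delta rule for one hidden layer with sigmoid activations and squared-error loss; the notation here is mine, not taken from the book):

\[
\delta^{o} = (y - o)\,o(1-o), \qquad
\delta^{h} = \left(\delta^{o} W^{\mathsf T}\right) \odot h(1-h)
\]
\[
\Delta W = \eta\, h^{\mathsf T}\delta^{o}, \qquad
\Delta V = \eta\, X^{\mathsf T}\delta^{h}
\]

Here \(h\) and \(o\) are the hidden- and output-layer activations, \(\eta\) is the learning rate, and \(\odot\) is element-wise multiplication.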
Python implementation:
import numpy as np

# input data: 4 samples, each a bias term of 1 plus two binary inputs, shape (4,3)
X = np.array([[1,0,0],
              [1,0,1],
              [1,1,0],
              [1,1,1]])
# XOR labels
Y = np.array([[0,1,1,0]])
# weight initialization, values uniform in [-1, 1]
V = (np.random.random((3,4))-0.5)*2  # input-to-hidden weights, shape (3,4)
W = (np.random.random((4,1))-0.5)*2  # hidden-to-output weights, shape (4,1)
print(W)
print(V)
# learning rate
lr = 0.11
# sigmoid activation function
def sigmoid(x):
    return 1/(1+np.exp(-x))
# sigmoid derivative, expressed in terms of the sigmoid output
def dsigmoid(x):
    return x*(1-x)
# weight update
def update():
    global X,Y,W,V,lr
    Output_hidden = sigmoid(np.dot(X,V))              # hidden-layer output, shape (4,4)
    Output_output = sigmoid(np.dot(Output_hidden,W))  # output-layer output, shape (4,1)
    # output-layer error signal
    Output_output_delta = (Y.T - Output_output)*dsigmoid(Output_output)
    # hidden-layer error signal
    Output_hidden_delta = Output_output_delta.dot(W.T)*dsigmoid(Output_hidden)
    # weight increments
    W_C = lr*Output_hidden.T.dot(Output_output_delta)
    V_C = lr*X.T.dot(Output_hidden_delta)
    W = W + W_C
    V = V + V_C

for i in range(20000):
    update()
    if i%500==0:
        Output_hidden = sigmoid(np.dot(X,V))              # hidden-layer output, shape (4,4)
        Output_output = sigmoid(np.dot(Output_hidden,W))  # output-layer output, shape (4,1)
        print('Error',np.mean(np.abs(Y.T-Output_output)))

Output_hidden = sigmoid(np.dot(X,V))              # hidden-layer output, shape (4,4)
Output_output = sigmoid(np.dot(Output_hidden,W))  # output-layer output, shape (4,1)
print(Output_output)
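One subtlety worth noting: dsigmoid takes the already-activated output, not the pre-activation input, i.e. it computes σ'(x) as y(1-y) where y = σ(x). A quick numerical check (this snippet is my own sanity test, not part of the original post) compares it against a central-difference derivative of sigmoid:

```python
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def dsigmoid(y):
    # derivative expressed in terms of the sigmoid output y = sigmoid(x)
    return y * (1 - y)

x = np.linspace(-5, 5, 101)
analytic = dsigmoid(sigmoid(x))

# central-difference numerical derivative of sigmoid
h = 1e-5
numeric = (sigmoid(x + h) - sigmoid(x - h)) / (2 * h)

# maximum discrepancy is at round-off level, confirming the identity
print(np.max(np.abs(analytic - numeric)))
```

This is why the code applies dsigmoid directly to Output_hidden and Output_output: those arrays already hold post-sigmoid values.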
Output:
[[ 0.79697584]
[-0.899471 ]
[-0.03021633]
[-0.73609684]]
[[-0.83225906 0.49092062 -0.93064959 0.39906478]
[-0.4626103 0.40386656 0.90899369 -0.84416671]
[-0.12415614 -0.30768482 -0.80695638 -0.11680479]]
Error 0.006722481867364967
Error 0.006711968080461744
Error 0.006701502339927143
Error 0.006691084283599529
Error 0.006680713553112746
Error 0.0066703897938456425
Error 0.006660112654871668
Error 0.0066498817889102235
Error 0.0066396968522776535
Error 0.006629557504839739
Error 0.006619463409965184
Error 0.006609414234478989
Error 0.006599409648617244
Error 0.006589449325982851
Error 0.006579532943500827
Error 0.006569660181375644
Error 0.006559830723048313
Error 0.006550044255154733
Error 0.006540300467484534
Error 0.006530599052939871
Error 0.006520939707496551
Error 0.006511322130163323
Error 0.006501746022944472
Error 0.006492211090801057
Error 0.0064827170416130505
Error 0.006473263586143195
Error 0.006463850438000195
Error 0.006454477313602788
Error 0.006445143932144832
Error 0.006435850015560173
Error 0.0064265952884888415
Error 0.00641737947824287
Error 0.006408202314773685
Error 0.00639906353063903
Error 0.006389962860970968
Error 0.006380900043443988
Error 0.006371874818244327
Error 0.00636288692803846
Error 0.006353936117943215
Error 0.006345022135495877
[[0.00497924]
[0.99446599]
[0.99287581]
[0.0077072 ]]
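The final outputs are close to the XOR targets 0, 1, 1, 0; thresholding at 0.5 recovers the labels exactly. A small sketch (using the outputs printed above; not in the original post):

```python
import numpy as np

# final network outputs from the run above
output = np.array([[0.00497924],
                   [0.99446599],
                   [0.99287581],
                   [0.0077072]])

# threshold at 0.5 to get hard class predictions
predictions = (output > 0.5).astype(int)
print(predictions.ravel())  # [0 1 1 0] -- matches the XOR labels
```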