import numpy as np
# sigmoid function sig函数和其导数整合
def nonlin(x, deriv=False):
if (deriv == True):
return x * (1 - x) # 如果deriv为true,求导数
return 1 / (1 + np.exp(-x)) # exp()是以e为底的指数函数
X = np.array([[0.35], [0.9]]) # 输入层
y = np.array([[0.5]]) # 输出值
np.random.seed(1)
# 初始权重
W0 = np.array([[0.1, 0.8], [0.4, 0.6]])
W1 = np.array([[0.3, 0.9]])
print("original", W0, "\n", W1)
for j in range(100):
# forward propagation
l0 = X # 相当于文章中x0
l1 = nonlin(np.dot(W0, l0)) # 相当于文章中y1
l2 = nonlin(np.dot(W1, l1)) # 相当于文章中y2
l2_error = y - l2
Error = 1 / 2.0 * (y - l2) ** 2 #1/2 *(误差)的平方
print("Error:", Error)
# back Propagation
l2_delta = l2_error * nonlin(l2, deriv=True) #中间层输出的偏导
l1_error = l2_delta * W1 # 反向传播 链式
l1_delta = l1_error * nonlin(l1, deriv=True) #输入层 输出时的偏导
print("l2_delta",l2_delta)
print("l1_error", l1_error)
print("l1_delta", l1_delta)
W1 += l2_delta * l1.T; # 修改权值
W0 += l0.T.dot(l1_delta)
print(W0, "\n", W1)
结果:
original [[0.1 0.8]
[0.4 0.6]]
[[0.3 0.9]]
Error: [[0.0181039]]
l2_delta [[-0.04068113]]
l1_error [[-0.01220434 -0.03661301]]
l1_delta [[-0.00265449 -0.00796347]
[-0.00272388 -0.00817165]]
[[0.09661944 0.78985831]
[0.39661944 0.58985831]]
[[0.27232597 0.87299836]]
Error: [[0.01652628]]
l2_delta [[-0.03944183]]
l1_error [[-0.01074104 -0.03443266]]
l1_delta [[-0.00234486 -0.00751695]
[-0.00240534 -0.00771082]]
[[0.09363393 0.78028763]
[0.39363393 0.58028763]]
[[0.2455836 0.84691021]]
Error: [[0.01506159]]
l2_delta [[-0.03816188]]
l1_error [[-0.00937193 -0.03231969]]
l1_delta [[-0.00205298 -0.00707983]
[-0.00210525 -0.0072601 ]]
[[0.09102066 0.7712756 ]
[0.39102066 0.5712756 ]]
[[0.21978966 0.82175133]]
Error: [[0.0137064]]
l2_delta [[-0.03685334]]
l1_error [[-0.00809998 -0.03028428]]
l1_delta [[-0.00177996 -0.00665494]
[-0.00182474 -0.00682234]]
[[0.08875541 0.76280627]
[0.38875541 0.56280627]]
[[0.19495316 0.79752995]]
Error: [[0.01245646]]
l2_delta [[-0.03552736]]
l1_error [[-0.00692617 -0.02833413]]
l1_delta [[-0.00152646 -0.00624455]
[-0.00156441 -0.00639983]]
[[0.08681317 0.75486083]
[0.38681317 0.55486083]]
[[0.17107608 0.7742475 ]]