github博客传送门
博客园传送门
本章所需知识:
- numpy
- matplotlib
资料下载链接:
- 深度学习基础网络模型(mnist手写体识别数据集)
梯度下降 BP 算法手动实现
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(1, 100, 100)
y = 2 * x + np.random.randn(*x.shape) * 10
step = 0.00001
diff = [0, 0]
cnt = 0
b = 0
w = 0
error0 = 0
error1 = 0
epsilon = 0.000001
def h(ax):
return w * ax + b
while True:
diff = [0, 0]
for i in range(len(x)):
diff[0] += h(x[i]) - y[i]
diff[1] += (h(x[i]) - y[i]) * x[i]
b = b - step / len(x) * diff[0]
w = w - step / len(x) * diff[1]
error1 = 0
for i in range(len(x)):
error1 += (y[i] - (b + w * x[i])) ** 2 / 2
if abs(error1 - error0) < epsilon:
break
else:
error0 = error1
plt.ion()
plt.clf()
plt.plot(x, [h(x) for x in x])
plt.plot(x, y, 'bo')
print(w, b)
plt.pause(0.1)
plt.ioff()
最后附上截图训练截图: