weight = 0.1
alpha = 0.01
def neural_network(input, weight): #构建了一个预测函数
prediction = input * weight
return prediction
# 2) PREDICT: Making A Prediction And Evaluating Error
number_of_toes = [8.5]
win_or_lose_binary = [1] # (won!!!)
print("输入因子 number_of_toes是:", number_of_toes)
print("win_or_lose_binary r 是:", win_or_lose_binary )
input = number_of_toes[0] #将数组number_of_toes[] 的第一个值赋给input
goal_pred = win_or_lose_binary[0] #将数组win_or_lose_binary[0] 的第一个值赋给goal_pred
print("数据 input 是:", input )
print("数据 goal_pred 是:", goal_pred )
pred = neural_network(input,weight) # 预测值
error = (pred - goal_pred) ** 2 #均方误差
print("input","*","weight","预测值 pred 是:", pred)
print("均方误差 error 是:", error )
# 3) COMPARE: Calculating "Node Delta" and Putting it on the Output Node
delta = pred - goal_pred #预测值和真实值的差值。纯误差
print("数据 纯误差 delta 是:", delta)
# 4) LEARN: Calculating "Weight Delta" and Putting it on the Weight
print("数据 input 是:", input )
weight_delta = input * delta # "缩放、负值反转和停止
#weight_ delta 是一个用于度量权重所导致的网络犯错的指标。计算它的方法是将权重的输出节点增量(delta)乘以权重的输入。
# 因此, 在我们逐个创建权重增量weight_ delta 时, 需要基于权重对应的输入的值,将输出节点增量(delta)进行缩放操作。
# 这就解释了前面提到的direction and amount 的三个属性:缩放、负值反转和停止调节。
print("缩放、负值反转和停止weight_delta 是:", weight_delta )
# 5) LEARN: Updating the Weight
weight_delta01 = weight_delta * alpha
print("数据 weight_delta01 是:", weight_delta01 )
print("^^^前面修改权重后的 weight是:", weight)
alpha = 0.01 # fixed before training
#在使用weight_delta 更新权重值前, 可将它乘以一个小数值alpha。这让你可以控制网络的学习速度。如果网络学得太快, alpha可以修正网络更新的速度,避免过度调整。稍后会对此进行详细介绍。请注意,这里权重更新的增量与冷热学斗的结果是相同的(小幅增加) 。
weight -= weight_delta * alpha
print("数据 alpha 是:", alpha )
print("修改权重后的 weight是:", weight)
pred = neural_network(input,weight) # 预测值
error = (pred - goal_pred) ** 2 #均方误差
print("第二次循环input","*","weight","预测值 pred 是:", pred)
print("第二次循环 均方误差 error 是:", error )
结果一次循环和两次循环结果:
C:\Users\admin\AppData\Local\Programs\Python\Python39\python.exe "D:\python test\BPtest.py"
输入因子 number_of_toes是: [8.5]
win_or_lose_binary r 是: [1]
数据 input 是: 8.5
数据 goal_pred 是: 1
input * weight 预测值 pred 是: 0.8500000000000001
均方误差 error 是: 0.022499999999999975
数据 纯误差 delta 是: -0.1499999999999999
数据 input 是: 8.5
缩放、负值反转和停止weight_delta 是: -1.2749999999999992
数据 weight_delta01 是: -0.012749999999999992
^^^前面修改权重后的 weight是: 0.1
数据 alpha 是: 0.01
修改权重后的 weight是: 0.11275
第二次循环input * weight 预测值 pred 是: 0.958375
第二次循环 均方误差 error 是: 0.001732640625000002
进程已结束,退出代码0