一、目的
模拟批量梯度下降算法:在 x_data、y_data 数据集下,为模型 $y = \omega x$ 找到合适的 $\omega$ 值。
二、编程
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# Dataset: paired inputs and targets (each target here is twice its input)
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
# Initial value of the weight; read by forward() and gradient() as a global
w = 1
# Forward pass
def forward(x):
    """Return the model prediction ``x * w`` for input ``x``.

    The weight ``w`` is read from module scope at call time, so the
    result reflects whatever value ``w`` currently holds.
    """
    return x * w
# Compute the loss
def cost(xs, ys):
    """Return the mean squared error of the model over paired samples.

    Args:
        xs: Input values.
        ys: Target values, paired element-wise with ``xs`` via ``zip``.

    Returns:
        The sum of squared prediction errors divided by ``len(xs)``.
    """
    # Accumulate with an explicit loop (not sum()) so the floating-point
    # addition order stays exactly as written.
    squared_error_total = 0
    for x, y in zip(xs, ys):
        y_pred = forward(x)
        squared_error_total += (y_pred - y) ** 2
    return squared_error_total / len(xs)
# Backward pass
def gradient(xs, ys):
    """Return the gradient of the mean squared error with respect to ``w``.

    For one sample the squared error is ``(x * w - y) ** 2``; its
    derivative with respect to ``w`` is ``2 * x * (x * w - y)``. The
    per-sample derivatives are summed and divided by ``len(xs)``.

    Args:
        xs: Input values.
        ys: Target values, paired element-wise with ``xs`` via ``zip``.
    """
    grad_total = 0
    for x, y in zip(xs, ys):
        grad_total += 2 * x * (x * w - y)
    return grad_total / len(xs)
# Start training: batch gradient descent, one weight update per epoch
mse_list = []
for epoch in range(100):
    # Cost at the current weight, i.e. before this epoch's update
    cost_val = cost(x_data, y_data)
    # Gradient averaged over the whole data set
    grad_val = gradient(x_data, y_data)
    # Step against the gradient with a learning rate of 0.01
    w = w - 0.01 * grad_val
    mse_list.append(cost_val)
    print("epoch=", epoch, "cost_val=", cost_val, "w=", w)
# Predict with the trained weight
print("x=4, y=", forward(4))
# Plot the recorded cost against the epoch index
plt.plot(mse_list)
plt.xlabel("epoch")
plt.ylabel("cost")
plt.show()
epoch= 0 cost_val= 4.666666666666667 w= 1.0933333333333333
epoch= 1 cost_val= 3.8362074074074086 w= 1.1779555555555554
epoch= 2 cost_val= 3.1535329869958857 w= 1.2546797037037036
epoch= 3 cost_val= 2.592344272332262 w= 1.3242429313580246
...
epoch= 97 cost_val= 2.593287985380858e-08 w= 1.9999324119941766
epoch= 98 cost_val= 2.131797981222471e-08 w= 1.9999387202080534
epoch= 99 cost_val= 1.752432687141379e-08 w= 1.9999444396553017
x=4, y= 7.999777758621207
在机器学习过程中,批量梯度下降可能会遇到“鞍点”(即所有样本的梯度之和为 0,此时 w = w - 0.01 * 0,w 不再改变),导致训练无法继续进行。为了解决这个问题,可以采用随机梯度下降:每次随机地取一组 (x, y) 的梯度作为梯度下降的依据,而不是用所有样本的梯度之和作为依据。其实质是利用单个样本带来的“噪声”去推动梯度下降。(下面的示例代码为了简单起见,在每个 epoch 内按固定顺序依次使用每个样本,并没有随机打乱。)
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# Dataset: paired inputs and targets (each target here is twice its input)
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
# Initial value of the weight; read by forward() and gradient() as a global
w = 1
# Forward pass
def forward(x):
    """Return the model prediction ``x * w`` for input ``x``.

    The weight ``w`` is read from module scope at call time, so the
    result reflects whatever value ``w`` currently holds.
    """
    return x * w
# Compute the loss
def loss(x, y):
    """Return the squared error of the model on a single sample.

    Args:
        x: Input value.
        y: Target value for ``x``.
    """
    y_pred = forward(x)
    return (y_pred - y) ** 2
# Backward pass
def gradient(x, y):
    """Return the gradient of one sample's squared error with respect to ``w``.

    The squared error is ``(x * w - y) ** 2``; differentiating with
    respect to ``w`` gives ``2 * x * (x * w - y)``.

    Args:
        x: Input value.
        y: Target value for ``x``.
    """
    return 2 * x * (x * w - y)
# Start training: one weight update per sample
mse_list = []
for epoch in range(100):
    # Train on one sample at a time; samples are visited in list order
    for x_val, y_val in zip(x_data, y_data):
        # Loss on this sample at the current weight (before the update)
        cost_val = loss(x_val, y_val)
        # Gradient from this single sample
        grad_val = gradient(x_val, y_val)
        # Step against the gradient with a learning rate of 0.01
        w = w - 0.01 * grad_val
    # cost_val now holds the loss of the last sample visited in this epoch
    mse_list.append(cost_val)
    print("epoch=", epoch, "cost_val=", cost_val, "w=", w)
# Predict with the trained weight
print("x=4, y=", forward(4))
# Plot the recorded cost against the epoch index
plt.plot(mse_list)
plt.xlabel("epoch")
plt.ylabel("cost")
plt.show()
epoch= 0 cost_val= 7.315943039999998 w= 1.260688
epoch= 1 cost_val= 3.9987644858206908 w= 1.453417766656
epoch= 2 cost_val= 2.1856536232765476 w= 1.5959051959019805
epoch= 3 cost_val= 1.1946394387269013 w= 1.701247862192685
...
epoch= 97 cost_val= 2.6081713678869703e-25 w= 1.9999999999998603
epoch= 98 cost_val= 1.4248800100554526e-25 w= 1.9999999999998967
epoch= 99 cost_val= 7.82747233205549e-26 w= 1.9999999999999236
x=4, y= 7.9999999999996945