Liu Er's "PyTorch Deep Learning Practice", Lecture 3: Gradient Descent
I. Miscellaneous Tips
1. plt.grid: drawing grid lines on a 2D plot
The simplest usage is a single line, plt.grid(); at most you might tweak the transparency.
A fuller call looks like this:
from matplotlib import pyplot as plt
x = [1, 2, 3, 4, 5]
y = [1, 4, 9, 16, 25]
plt.plot(x, y)
plt.grid(True, linestyle='dashed', color='gray', alpha=0.5)
plt.show()
linestyle='dashed': the grid lines are dashed
color='gray': the grid lines are gray
alpha=0.5: the transparency of the grid lines
II. Gradient Descent
Compute the gradient: gradient = ∂cost/∂ω
Update the weight: ω = ω − α · gradient, where α is the learning rate
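For the linear model ŷ = xω used throughout this lecture with a mean-squared-error cost, the gradient can be worked out explicitly; this derivation is exactly what the gradient() function in the code below computes:

```latex
\mathrm{cost}(\omega) = \frac{1}{N}\sum_{n=1}^{N}\left(x_n \omega - y_n\right)^2
\qquad
\frac{\partial\,\mathrm{cost}}{\partial \omega}
  = \frac{1}{N}\sum_{n=1}^{N} 2\,x_n\left(x_n \omega - y_n\right)
```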
from matplotlib import pyplot as plt

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

w = 1.0  # initial weight, an arbitrary guess

def forward(x):
    return x * w

# cost is the mean squared error over the whole dataset
def cost(xs, ys):
    N = len(xs)
    cost = 0
    for x, y in zip(xs, ys):
        y_pred = forward(x)
        cost += (y_pred - y) ** 2
    return cost / N

# analytic gradient of the cost with respect to w
def gradient(xs, ys):
    N = len(xs)
    grad = 0
    for x, y in zip(xs, ys):
        grad += 2 * x * (x * w - y)
    return grad / N

print('Predict (before training)', 4, forward(4))
# with the initial weight, x = 4 gives y = 4 * 1.0 = 4

epoch_list = []
cost_list = []
for epoch in range(100):
    cost_val = cost(x_data, y_data)
    grad_val = gradient(x_data, y_data)
    w -= 0.01 * grad_val  # learning rate α = 0.01
    print(f'Epoch:{epoch},w={w},loss={cost_val}')
    epoch_list.append(epoch)
    cost_list.append(cost_val)

print('Predict (after training)', 4, forward(4))

plt.plot(epoch_list, cost_list)
plt.xlabel("epoch")
plt.ylabel("cost")
plt.grid(alpha=0.4)
plt.show()
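As a sanity check (my own addition, not part of the lecture): for the model ŷ = xω, the least-squares optimum has the closed form ω* = Σxy / Σx², so the batch updates above should drive w toward it. A compact one-line version of the same update:

```python
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# closed-form least-squares solution for y = x * w
w_star = sum(x * y for x, y in zip(x_data, y_data)) / sum(x * x for x in x_data)

# rerun the batch gradient-descent update from above
w = 1.0
for epoch in range(100):
    grad = sum(2 * x * (x * w - y) for x, y in zip(x_data, y_data)) / len(x_data)
    w -= 0.01 * grad

print(w_star, w)  # both close to 2.0
```

Since the data lie exactly on y = 2x, ω* is exactly 2.0, and 100 epochs of gradient descent land within about 1e-4 of it.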
III. Stochastic Gradient Descent
Stochastic gradient descent differentiates the loss of a single sample with respect to the weight and updates immediately, once per sample.
from matplotlib import pyplot as plt

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

w = 1.0  # initial weight, an arbitrary guess

def forward(x):
    return x * w

# loss is the squared error of a single sample
def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2

# gradient of the single-sample loss with respect to w
def gradient(x, y):
    return 2 * x * (x * w - y)

print('Predict (before training)', 4, forward(4))
# with the initial weight, x = 4 gives y = 4 * 1.0 = 4

epoch_list = []
w_list = []
for epoch in range(100):
    for x, y in zip(x_data, y_data):
        grad = gradient(x, y)
        w -= 0.01 * grad  # update after every single sample
        print("\tgrad:", x, y, grad)
        l = loss(x, y)
    w_list.append(w)
    epoch_list.append(epoch)
    print(f"progress:{epoch},w={w},loss={l}")

print('Predict (after training)', 4, forward(4))

plt.plot(epoch_list, w_list)
plt.ylabel('w')
plt.xlabel('epoch')
plt.grid(alpha=0.4)
plt.show()
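Between full-batch gradient descent and per-sample SGD sits mini-batch SGD, the compromise used in practice: average the gradient over a small random subset per step. A minimal sketch on the same toy data (the batch size of 2 is my own choice for illustration, not from the lecture):

```python
import random

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

w = 1.0
for epoch in range(100):
    # sample a mini-batch of 2 points and average their gradients
    batch = random.sample(list(zip(x_data, y_data)), 2)
    grad = sum(2 * x * (x * w - y) for x, y in batch) / len(batch)
    w -= 0.01 * grad

print(w)  # converges close to the true weight 2.0
```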
IV. Unrelated Notes
P.S. This feels connected to the steepest-descent method from my first-year graduate course "Optimization Theory" and the mean squared error from "Numerical Statistics"; the courses I took are finally not completely useless!
By the way, a quick plug for my Xiaohongshu account, from a confused beginner who is nonetheless working hard to shine.