import torch as t
from matplotlib import pyplot as plt
import numpy as np
t.manual_seed(1000) # fix the random seed
def get_fake_data(batch_size=8):
    """Generate one batch of noisy samples scattered around y = 2x + 3.

    Args:
        batch_size: Number of (x, y) pairs to draw.

    Returns:
        A pair ``(x, y)`` of tensors, each of shape ``(batch_size, 1)``.
        ``x`` is ``torch.rand`` scaled by 20; ``y`` is ``2 * x + 3`` plus
        ``torch.randn`` noise scaled by 3.
    """
    x = t.rand(batch_size, 1) * 20
    noise = t.randn(batch_size, 1)
    # (1 + noise) * 3 expands to 3 + 3 * noise: the intercept plus the noise.
    y = x * 2 + (1 + noise) * 3
    return x, y
# Linear regression y_pred = w * x + b trained by mini-batch gradient descent,
# with the gradients derived and applied by hand (no autograd).
w = t.rand(1, 1)   # slope, randomly initialised
b = t.zeros(1, 1)  # intercept, initialised to zero
lr = 0.001  # learning rate
print("w:", w, "b:", b)

# x-coordinates used to draw the fitted and ground-truth lines; this does not
# change between iterations, so it is built once outside the loop.
a = np.arange(25)

plt.ion()
for ii in range(20000):
    x, y = get_fake_data()

    # Forward pass: w broadcasts over the batch dimension of x.
    y_pred = w * x + b

    # Loss: half the sum of squared errors over the batch. It is not used by
    # the manual gradient code below but is kept available for inspection.
    loss = 0.5 * (y_pred - y) ** 2
    loss = loss.sum()

    # Manual backward pass (chain rule applied by hand).
    dloss = 1
    dy_pred = dloss * (y_pred - y)  # d(loss)/d(y_pred)
    dw = x.t().mm(dy_pred)          # d(loss)/d(w) = sum_i x_i * dy_pred_i
    db = dy_pred.sum()              # d(loss)/d(b) = sum_i dy_pred_i

    # In-place parameter update.
    w.sub_(lr * dw)
    b.sub_(lr * db)

    if ii % 100 == 0:  # visualise progress every 100 iterations
        plt.clf()  # clear the previous frame
        plt.scatter(x.squeeze().numpy(), y.squeeze().numpy())
        plt.plot(a, w.item()*a + b.item())  # current prediction
        plt.plot(a, 2*a + 3)  # ground truth
        plt.xlim(0, 20)
        plt.ylim(0, 41)
        plt.pause(0.01)

# Turn interactive mode back off; this does not close the figure window.
plt.ioff()
print("w:", w.squeeze().item(), "b:", b.squeeze().item())  # final w and b
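# The script above performs backpropagation with manually computed gradients; the script below implements the same backpropagation using autograd.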
import torch as t
from matplotlib import pyplot as plt
import numpy as np
from torch.autograd import Variable as V
t.manual_seed(1000) # fix the random seed
def get_fake_data(batch_size=8):
    """Generate one batch of noisy samples scattered around y = 2x + 3.

    Args:
        batch_size: Number of (x, y) pairs to draw.

    Returns:
        A pair ``(x, y)`` of tensors, each of shape ``(batch_size, 1)``.
        ``x`` is ``torch.rand`` scaled by 20; ``y`` is ``2 * x + 3`` plus
        ``torch.randn`` noise scaled by 3.
    """
    x = t.rand(batch_size, 1) * 20
    noise = t.randn(batch_size, 1)
    # (1 + noise) * 3 expands to 3 + 3 * noise: the intercept plus the noise.
    y = x * 2 + (1 + noise) * 3
    return x, y
# Linear regression y_pred = w * x + b trained by mini-batch gradient descent,
# with the gradients computed by autograd.
#
# Tensors are created with requires_grad=True directly; the Variable wrapper
# is deprecated and no longer needed to track gradients.
w = t.rand(1, 1, requires_grad=True)   # slope, randomly initialised
b = t.zeros(1, 1, requires_grad=True)  # intercept, initialised to zero
lr = 0.001  # learning rate
print("w:", w, "b:", b)

# x-coordinates used to draw the fitted and ground-truth lines; this does not
# change between iterations, so it is built once outside the loop.
a = np.arange(25)

plt.ion()
for ii in range(20000):
    # Inputs and targets do not require gradients, so they are used as-is.
    x, y = get_fake_data()

    # Forward pass: w broadcasts over the batch dimension of x.
    y_pred = w * x + b

    # Loss: half the sum of squared errors over the batch.
    loss = 0.5 * (y_pred - y) ** 2
    loss = loss.sum()

    # Backward pass: autograd accumulates d(loss)/d(w) and d(loss)/d(b) into
    # w.grad and b.grad.
    loss.backward()

    # Parameter update. Done under no_grad so that the in-place change to the
    # leaf tensors is not recorded in the autograd graph.
    with t.no_grad():
        w.sub_(lr * w.grad)
        b.sub_(lr * b.grad)

    # Gradients accumulate across backward() calls, so reset them each step.
    w.grad.zero_()
    b.grad.zero_()

    if ii % 100 == 0:  # visualise progress every 100 iterations
        plt.clf()  # clear the previous frame
        plt.scatter(x.squeeze().numpy(), y.squeeze().numpy())
        plt.plot(a, w.item()*a + b.item())  # current prediction
        plt.plot(a, 2*a + 3)  # ground truth
        plt.xlim(0, 20)
        plt.ylim(0, 41)
        plt.pause(0.01)

# Turn interactive mode back off; this does not close the figure window.
plt.ioff()
print("w:", w.squeeze().item(), "b:", b.squeeze().item())  # final w and b