1.针对y_hat = w1*x^2+w2*x+b的反向传播计算图构建
2.代码实现
import torch
import numpy as np
import matplotlib.pyplot as plt
# Training data: three (x, y) samples. They satisfy y = 2x^2 + 3x + 2 exactly.
x_data = torch.tensor([1.0, 2.0, 3.0])
y_data = torch.tensor([7.0, 16.0, 29.0])

# Model parameters with their initial guesses (the exact-fit value is noted on each).
w1 = torch.tensor([1.0], requires_grad=True)  # coefficient of x^2; exact fit is 2
w2 = torch.tensor([2.0], requires_grad=True)  # coefficient of x; exact fit is 3
b = torch.tensor([1.0])  # bias; exact fit is 2. Created WITHOUT requires_grad here.
# Forward pass of the quadratic model y_hat = w1*x^2 + w2*x + b.
def forward(x):
    """Return the model prediction for input ``x``.

    Uses the module-level parameters ``w1``, ``w2`` and ``b``.
    """
    square_term = w1 * (x ** 2)
    linear_term = w2 * x
    return square_term + linear_term + b
def loss(x, y):
    """Return the squared error between ``forward(x)`` and the target ``y``."""
    residual = forward(x) - y
    return residual ** 2
# Prediction for an unseen input before any training.
print("Predict (before training)", 4, forward(4).item())

# Train with per-sample stochastic gradient descent.
# The model is y_hat = w1*x^2 + w2*x + b, so all three tensors are learnable.
# The bias's exact-fit value is 2, so it has to be learned as well.
learning_rate = 0.01
num_epochs = 100
params = (w1, w2, b)
for p in params:
    # Make sure every parameter is a leaf that receives a .grad from backward().
    p.requires_grad_(True)

for epoch in range(num_epochs):
    epoch_loss = 0.0
    for x, y in zip(x_data, y_data):
        sample_loss = loss(x, y)
        sample_loss.backward()  # accumulates d(loss)/d(param) into each p.grad
        print("\tgrad:", x, y, w1.grad.item(), w2.grad.item(), b.grad.item())
        # Perform the update outside autograd so the step itself is not recorded.
        with torch.no_grad():
            for p in params:
                p -= learning_rate * p.grad
                # backward() accumulates gradients, so reset them before the next sample.
                p.grad.zero_()
        epoch_loss += sample_loss.item()
    # Report the mean per-sample loss over the epoch rather than only the last sample's loss.
    print("Progress:", epoch, epoch_loss / len(x_data))

print("Predict (after training)", 4, forward(4).item())
# The expected y at x = 4 is 2*16 + 3*4 + 2 = 46.
3.结果展示
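初始预测（Predict (before training) 4 25.0）vs. 最终预测（Predict (after training) 4 45.65895462036133）vs. 真实值（x = 4 时 y = 46）。注意：在这次运行中 b 没有设置 requires_grad，始终固定为 1，只有 w1、w2 被更新（日志中每步只打印了两个梯度）。由于 b 被固定，模型无法精确拟合数据，所以最终预测略低于 46。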
D:\Anaconda3\envs\pytorch\python.exe E:/learn_pytorch/LE/Back_homework.py
Predict (before training) 4 25.0
grad: tensor(1.) tensor(7.) -6.0 -6.0
grad: tensor(2.) tensor(16.) -53.12000274658203 -26.560001373291016
grad: tensor(3.) tensor(29.) -120.64320373535156 -40.21440124511719
Progress: 0 44.92216873168945
grad: tensor(1.) tensor(7.) -0.949249267578125 -0.949249267578125
grad: tensor(2.) tensor(16.) 13.623748779296875 6.8118743896484375
grad: tensor(3.) tensor(29.) 72.81600952148438 24.272003173828125
Progress: 1 16.36472511291504
grad: tensor(1.) tensor(7.) -3.261752128601074 -3.261752128601074
grad: tensor(2.) tensor(16.) -17.444747924804688 -8.722373962402344
grad: tensor(3.) tensor(29.) -18.236858367919922 -6.078952789306641
Progress: 2 1.0264908075332642
grad: tensor(1.) tensor(7.) -2.1216230392456055 -2.1216230392456055
grad: tensor(2.) tensor(16.) -2.640045166015625 -1.3200225830078125
grad: tensor(3.) tensor(29.) 24.161853790283203 8.053951263427734
Progress: 3 1.8018369674682617
grad: tensor(1.) tensor(7.) -2.6018733978271484 -2.6018733978271484
grad: tensor(2.) tensor(16.) -9.355560302734375 -4.6777801513671875
grad: tensor(3.) tensor(29.) 3.9725875854492188 1.3241958618164062
Progress: 4 0.04870818555355072
grad: tensor(1.) tensor(7.) -2.3230667114257812 -2.3230667114257812
grad: tensor(2.) tensor(16.) -5.981361389160156 -2.990680694580078
grad: tensor(3.) tensor(29.) 13.144454956054688 4.3814849853515625
Progress: 5 0.5332614183425903
grad: tensor(1.) tensor(7.) -2.4012231826782227 -2.4012231826782227
grad: tensor(2.) tensor(16.) -7.3434906005859375 -3.6717453002929688
grad: tensor(3.) tensor(29.) 8.550315856933594 2.8501052856445312
Progress: 6 0.225641667842865
grad: tensor(1.) tensor(7.) -2.312877655029297 -2.312877655029297
grad: tensor(2.) tensor(16.) -6.488037109375 -3.2440185546875
grad: tensor(3.) tensor(29.) 10.417957305908203 3.4726524353027344
Progress: 7 0.33498096466064453
grad: tensor(1.) tensor(7.) -2.3035335540771484 -2.3035335540771484
grad: tensor(2.) tensor(16.) -6.6764984130859375 -3.3382492065429688
grad: tensor(3.) tensor(29.) 9.25982666015625 3.08660888671875
Progress: 8 0.26464319229125977
grad: tensor(1.) tensor(7.) -2.258026123046875 -2.258026123046875
grad: tensor(2.) tensor(16.) -6.379051208496094 -3.189525604248047
grad: tensor(3.) tensor(29.) 9.525867462158203 3.1752891540527344
Progress: 9 0.28006836771965027
grad: tensor(1.) tensor(7.) -2.230358123779297 -2.230358123779297
grad: tensor(2.) tensor(16.) -6.3131866455078125 -3.1565933227539062
grad: tensor(3.) tensor(29.) 9.128814697265625 3.042938232421875
Progress: 10 0.25720757246017456
grad: tensor(1.) tensor(7.) -2.1951828002929688 -2.1951828002929688
grad: tensor(2.) tensor(16.) -6.142303466796875 -3.0711517333984375
grad: tensor(3.) tensor(29.) 9.047515869140625 3.015838623046875
Progress: 11 0.2526467442512512
grad: tensor(1.) tensor(7.) -2.1643733978271484 -2.1643733978271484
grad: tensor(2.) tensor(16.) -6.02423095703125 -3.012115478515625
grad: tensor(3.) tensor(29.) 8.822845458984375 2.940948486328125
Progress: 12 0.2402549386024475
grad: tensor(1.) tensor(7.) -2.1323471069335938 -2.1323471069335938
grad: tensor(2.) tensor(16.) -5.8848724365234375 -2.9424362182617188
grad: tensor(3.) tensor(29.) 8.669963836669922 2.8899879455566406
Progress: 13 0.23200084269046783
grad: tensor(1.) tensor(7.) -2.1017065048217773 -2.1017065048217773
grad: tensor(2.) tensor(16.) -5.7588958740234375 -2.8794479370117188