使用梯度下降法拟合二次函数 y=ax^2+b（该模型对参数 a、b 是线性的），代码为 Python
import torch
import matplotlib.pyplot as plt
# Synthetic data set: y = t_weight * x**2 + t_bais plus a little Gaussian noise.
t_weight, t_bais = 1, 10
# Seed before any random draw so the initial guess and the noise are reproducible.
torch.manual_seed(1314)
# Random starting point for the two trainable parameters (two 0-d tensors).
weight, bais = torch.rand(2).unbind()
x_train = torch.linspace(0, 5, 1000)
y_train = t_weight * x_train.pow(2) + t_bais + 0.01 * torch.randn(1000)
# Visual check of the generated curve.
plt.plot(x_train, y_train)
plt.show()
def fit_quadratic_sgd(x, y, weight, bias, lr=0.001, nepoch=100, verbose=True):
    """Fit ``y ~ weight * x**2 + bias`` by per-sample stochastic gradient descent.

    Every sample contributes the loss ``0.5 * (y_i - (weight * x_i**2 + bias))**2``.
    The parameters are updated after each sample, visiting the samples in their
    given order once per epoch. After every epoch the mean loss over the whole
    data set is evaluated with the current parameters. The parameters with the
    lowest such loss are kept as the "best" snapshot.

    Args:
        x, y: 1-D sequences (lists or tensors) of equal, non-zero length.
        weight, bias: initial parameter values; anything ``float()`` accepts,
            e.g. 0-d tensors. They are copied and never modified in place.
        lr: learning rate.
        nepoch: number of passes over the data.
        verbose: if true, print a progress report after every epoch.

    Returns:
        A dict containing:
        - the final parameters (``"weight"``, ``"bias"``);
        - the best end-of-epoch snapshot (``"best_weight"``, ``"best_bias"``,
          ``"best_loss"``, ``"best_epoch"``). When ``nepoch`` is 0,
          ``best_epoch`` is -1 and the initial parameters are reported;
        - the per-epoch mean losses (``"history"``).

    Raises:
        ValueError: if ``x`` and ``y`` are empty or differ in length.
    """
    xs = torch.as_tensor(x, dtype=torch.float64).tolist()
    ys = torch.as_tensor(y, dtype=torch.float64).tolist()
    if not xs or len(xs) != len(ys):
        raise ValueError("x and y must be non-empty and of equal length")

    # The model only ever uses x**2, so square once up front.
    samples = [(xi * xi, yi) for xi, yi in zip(xs, ys)]
    # Plain Python floats are immutable, so the "best" snapshot taken below
    # cannot be changed by later updates. An in-place `-=` on a tensor would
    # alias it. Floats also avoid per-scalar tensor overhead in the hot loop.
    w, b = float(weight), float(bias)

    best_loss, best_epoch, best_w, best_b = float("inf"), -1, w, b
    history = []
    for epoch in range(nepoch):
        for sq, target in samples:
            residual = target - (w * sq + b)
            # Sample-loss gradients: dL/dw = -residual * x**2, dL/db = -residual.
            # Step against them.
            w += lr * residual * sq
            b += lr * residual

        # Judge the epoch by the loss over the whole data set with the
        # end-of-epoch parameters (the ones that get snapshotted), not by a
        # single sample's loss.
        epoch_loss = sum(
            0.5 * (t - (w * s + b)) ** 2 for s, t in samples
        ) / len(samples)
        history.append(epoch_loss)
        if epoch_loss < best_loss:
            best_loss, best_epoch, best_w, best_b = epoch_loss, epoch, w, b

        if verbose:
            print("loss:%5.5f min_loss:%5.5f" % (epoch_loss, best_loss))
            print("epoch:%5d min_epoch:%5d" % (epoch, best_epoch))
            print("best-> weight:%5.3f bias:%5.3f" % (best_w, best_b))
            print("now--> weight:%5.3f bias:%5.3f" % (w, b))
            print("----------------------------------------")

    return {
        "weight": w,
        "bias": b,
        "best_weight": best_w,
        "best_bias": best_b,
        "best_loss": best_loss,
        "best_epoch": best_epoch,
        "history": history,
    }


if __name__ == "__main__":
    lr = 0.001
    nepoch = 100
    print("true weight:%g true bias:%g" % (t_weight, t_bais))
    print("origin weight:%.4f origin bias:%.4f" % (weight, bais))
    result = fit_quadratic_sgd(x_train, y_train, weight, bais, lr=lr, nepoch=nepoch)
    # Publish the results under the script's global names.
    weight, bais = result["weight"], result["bias"]
    pred_weight, pred_bais = result["best_weight"], result["best_bias"]
    mini_loss, min_epoch = result["best_loss"], result["best_epoch"]