import torch
import matplotlib.pyplot as plt
import numpy as np
# Generate a synthetic dataset: y = w0*x + w1*x^2 + w2*x^3 + b, plus Gaussian noise
def create(w, b, sample):
    x = torch.normal(0, 1, size=(sample,))  # 1-D inputs; (sample, len(w)) would break the element-wise polynomial below
    y = w[0]*x + w[1]*(x**2) + w[2]*(x**3) + b
    y += torch.normal(0, 0.01, y.shape)  # additive observation noise
    return x, y
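# Quick shape check (a hypothetical call; the values mirror t_w/t_b below):
# x, y = create(torch.tensor([1.2, -3.4, 5.6]), 5, 20)
# -> x.shape == torch.Size([20]), y.shape == torch.Size([20])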
t_w = torch.tensor([1.2, -3.4, 5.6])
t_b = 5
plt.figure()  # plt.figure(), not plt.Figure(): Figure() builds an object pyplot never shows
features, labels = create(t_w, t_b, 20)
plt.plot(features, labels, 'r+')
plt.show()
# Initialize the learnable parameters and the learning rate
w = torch.normal(0, 0.01, t_w.shape, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
lr = 0.025
# The model: a cubic polynomial in x with learnable coefficients
def net(x, w, b):
    return w[0]*x + w[1]*(x**2) + w[2]*(x**3) + b
# Plain SGD update, wrapped in no_grad so the updates themselves are not tracked by autograd
def sgd(params, lr):
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad
            param.grad.zero_()  # clear the gradient for the next backward pass
# Halved mean squared error
def Loss(y_hat, y):
    return (y_hat - y)**2 / 2 / len(y)
def Train(x, y, w, b, lr):
    for epoch in range(100):
        y_hat = net(x, w, b)
        l = Loss(y_hat, y).sum()
        l.backward()
        sgd([w, b], lr)
        print('loss=%.4f' % l.item())
Train(features, labels, w, b, lr)
w = w.detach().numpy()
b = b.detach().numpy()
print(w, b)
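To eyeball the fit, here is a minimal sketch that overlays the learned polynomial on the noisy samples (fx, x_plot and the 200-point grid are my own illustrative choices, not part of the original):

fx = features.numpy()
x_plot = np.linspace(fx.min(), fx.max(), 200)
y_plot = w[0]*x_plot + w[1]*(x_plot**2) + w[2]*(x_plot**3) + b  # w, b are numpy arrays at this point
plt.figure()
plt.plot(features, labels, 'r+', label='samples')
plt.plot(x_plot, y_plot, 'b-', label='fitted polynomial')
plt.legend()
plt.show()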
The fit obtained here is
w = [1.37, -3.30, 5.54]
b = 4.86
which is close to the weights of the target polynomial,
t_w = [1.2, -3.4, 5.6]
t_b = 5.
Of course, the loss can be reduced further by adjusting the number of iterations and the learning rate.
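As a sketch of that tuning (the 1000 epochs and the smaller step 0.01 below are illustrative values I picked, not tuned results), you can re-initialize fresh parameters w2, b2 and train longer:

# Re-initialize and train longer with a smaller learning rate (illustrative values)
w2 = torch.normal(0, 0.01, t_w.shape, requires_grad=True)
b2 = torch.zeros(1, requires_grad=True)
for epoch in range(1000):
    l = Loss(net(features, w2, b2), labels).sum()
    l.backward()
    sgd([w2, b2], 0.01)
print('final loss=%.4f' % l.item())
print(w2.detach().numpy(), b2.detach().numpy())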