拟合函数y = 0.90 + 0.50 * x + 3.00 * x^2 + 2.40 * x^3
import torch
import numpy as np
import matplotlib.pyplot as plt
#%%
# 定义一个多变量函数
w_target = np.array([0.5, 3, 2.4])
b_target = np.array([0.9])
#%%
f_des = 'y = {:.2f} + {:.2f} * x + {:.2f} * x^2 + {:.2f} * x^3'.format(
b_target[0], w_target[0], w_target[1], w_target[2])
print(f_des)
print(f'y = {b_target[0]} + {w_target[0]} * x + {w_target[1]} * x^2 + {w_target[2]} * x^3')
#%%
# 画出函数图像
x_sample = np.arange(-3, 3.1, 0.1)
y_sample = b_target[0] + w_target[0] * x_sample + w_target[1] * x_sample ** 2 +w_target[2] * x_sample ** 3
plt.plot(x_sample, y_sample, label='real curve')
plt.legend()
构建数据集,需要x和y,同时是一个三次多项式,所以我们取了x,x2,x3
# 构造数据x和y
# x是一个矩阵[[x,x2,x3]]
# y是函数结果 [[y]]
x_train = np.stack([x_sample ** i for i in range(1, 4)], axis=1)
x_train = torch.from_numpy(x_train).float()
y_train = torch.from_numpy(y_sample).float().unsqueeze(1)
# 初始化权重
w = torch.randn((3,1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)
# 定义模型
def multi_linear(x):
return torch.mm(x, w) + b
# 定义损失函数
def get_loss(y_pred,y):
return torch.mean((y_pred-y)**2)
# 进行100次参数更新
for e in range(100):
y_pred = multi_linear(x_train)
loss = get_loss(y_pred, y_train)
w.grad.data.zero_()
b.grad.data.zero_()
loss.backward()
# 更新参数
w.data = w.data - 0.001 * w.grad.data
b.data = b.data - 0.001 * b.grad.data
if (e + 1) % 20 == 0:
print('epoch {}, Loss: {:.5f}'.format(e+1, loss.item()))
# 画出图像
y_pred = multi_linear(x_train)
plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve',
color='r')
plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')
plt.legend()
可以看到,经过100次更新之后,可以看到拟合的线和真实的线已经完全重合了