pytorch 多元回归

拟合函数y = 0.90 + 0.50 * x + 3.00 * x^2 + 2.40 * x^3

import torch
import numpy as np
import matplotlib.pyplot as plt

#%%
# 定义一个多变量函数
w_target = np.array([0.5, 3, 2.4])
b_target = np.array([0.9]) 

#%%

f_des = 'y = {:.2f} + {:.2f} * x + {:.2f} * x^2 + {:.2f} * x^3'.format(
    b_target[0], w_target[0], w_target[1], w_target[2]) 
print(f_des)
print(f'y = {b_target[0]} + {w_target[0]} * x + {w_target[1]} * x^2 + {w_target[2]} * x^3')
#%%
# 画出函数图像
x_sample = np.arange(-3, 3.1, 0.1)
y_sample = b_target[0] + w_target[0] * x_sample + w_target[1] * x_sample ** 2 +w_target[2] * x_sample ** 3
plt.plot(x_sample, y_sample, label='real curve')
plt.legend()

在这里插入图片描述
构建数据集,需要x和y,同时是一个三次多项式,所以我们取了x,x2,x3

# 构造数据x和y
# x是一个矩阵[[x,x2,x3]]
# y是函数结果 [[y]]
x_train = np.stack([x_sample ** i for i in range(1, 4)], axis=1)
x_train = torch.from_numpy(x_train).float() 
y_train = torch.from_numpy(y_sample).float().unsqueeze(1) 
# 初始化权重 
w = torch.randn((3,1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)
# 定义模型
def multi_linear(x):
    return torch.mm(x, w) + b
# 定义损失函数
def get_loss(y_pred,y):
    return torch.mean((y_pred-y)**2)
    
# 进行100次参数更新
for e in range(100):
    y_pred = multi_linear(x_train)
    loss = get_loss(y_pred, y_train)
    
    w.grad.data.zero_()
    b.grad.data.zero_()
    loss.backward()
    
    # 更新参数
    w.data = w.data - 0.001 * w.grad.data
    b.data = b.data - 0.001 * b.grad.data
    if (e + 1) % 20 == 0:
        print('epoch {}, Loss: {:.5f}'.format(e+1, loss.item()))

# 画出图像
y_pred = multi_linear(x_train)
plt.plot(x_train.data.numpy()[:, 0], y_pred.data.numpy(), label='fitting curve',
color='r')
plt.plot(x_train.data.numpy()[:, 0], y_sample, label='real curve', color='b')
plt.legend()

在这里插入图片描述

可以看到,经过100次更新之后,可以看到拟合的线和真实的线已经完全重合了

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值