2021-06-29

pytorch多项式回归

import torch
import numpy as np
from icecream import ic
from torch.autograd import Variable
import matplotlib.pyplot as plt

w_target = np.array([0.5, 3, 2.4]) # 定义参数
b_target = np.array([0.9]) # 定义参数


x_sample = np.arange(-3, 3.1, 0.1)
y_sample = b_target[0] + w_target[0] * x_sample + w_target[1] * x_sample ** 2 + w_target[2] * x_sample ** 3


x_train= np.stack([x_sample ** i for i in range(1,4)],axis=1)

#训练集
x_train = Variable(torch.Tensor(x_train).float())
y_train = Variable(torch.from_numpy(y_sample).float().unsqueeze(1))

#初始化参数
w = Variable(torch.randn(3, 1), requires_grad=True)
b = Variable(torch.Tensor(1), requires_grad=True)


def y_predict():
    return torch.matmul(x_train, w) + b

def loss_(y_,y):
    return torch.mean((y_ - y) ** 2)



lr = 0.001
#循环迭代
for i in range(40):
    y_ = y_predict()
    loss = loss_(y_, y_train)
    #反向传播
    loss.backward()

    #更新参数
    w.data = w.data - lr * w.grad.data
    b.data = b.data - lr * b.grad.data

    w.grad.data.zero_()
    b.grad.data.zero_()
    if (i+1) % 2 == 0:
        print(f'epoch:{i},loss:{loss.data}')

plt.plot(x_sample, y_sample, label='real curve', color='b')
plt.plot(x_sample, y_predict().data.numpy(), label='fitting curve', color='r')
plt.legend()
plt.show()

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值