Fitting Polynomial Regression with PyTorch

Main Idea

The target polynomial is f(x) = -1.13x - 2.14x^2 + 3.15x^3 - 0.11x^4 + 0.512.

The input features are [x, x^2, x^3, x^4].

The parameters to fit are [-1.13, -2.14, 3.15, -0.11], plus the bias 0.512.

Since the model is linear in these features, no activation layer is needed; a single linear layer is enough.

For validation, the model is evaluated on a fixed grid of evenly spaced points in [-2, 2], while training batches are drawn at random, so the grid serves as held-out data. (As a sanity check, the coefficients can also be recovered in closed form; see the sketch below.)
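Because the model is linear in the features, ordinary least squares recovers the coefficients directly. The following is a minimal sanity-check sketch, not part of the original post; it assumes PyTorch >= 1.9 for torch.linalg.lstsq:

import torch

xs = torch.randn(1000)                                   # random sample points
X = torch.stack([xs ** i for i in range(1, 5)], dim=1)   # features [x, x^2, x^3, x^4]
X = torch.cat([X, torch.ones(len(xs), 1)], dim=1)        # column of ones for the bias
ys = -1.13*xs - 2.14*xs**2 + 3.15*xs**3 - 0.11*xs**4 + 0.512

# Solve min ||X w - ys||^2 in closed form
w = torch.linalg.lstsq(X, ys.unsqueeze(1)).solution
print(w.squeeze())  # expected: roughly [-1.13, -2.14, 3.15, -0.11, 0.512]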

Full Code

# Polynomial: f(x) = -1.13x - 2.14x^2 + 3.15x^3 - 0.11x^4 + 0.512
import torch
import matplotlib.pyplot as plt

# 50 evenly spaced points in [-2, 2], used as the fixed validation grid
x = torch.linspace(-2, 2, 50)
y = -1.13*x - 2.14*torch.pow(x, 2) + 3.15*torch.pow(x, 3) - 0.11*torch.pow(x, 4) + 0.512
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()

def features(x):
    # Build the feature matrix [x, x^2, x^3, x^4]
    x = x.unsqueeze(1)
    return torch.cat([x ** i for i in range(1, 5)], 1)

x_weights = torch.Tensor([-1.13, -2.14, 3.15, -0.11]).unsqueeze(1)
b = torch.Tensor([0.512])
def target(x):
    return x.mm(x_weights) + b.item()  # matrix multiply, then add the bias
# Randomly generate a batch of training data
def get_batch_data(batch_size):
    batch_x = torch.randn(batch_size)
    features_x = features(batch_x)
    target_y = target(features_x)
    return features_x, target_y
# Build the model: a single linear layer, no activation
class PolynomialRegression(torch.nn.Module):
    def __init__(self):
        super(PolynomialRegression, self).__init__()
        self.poly = torch.nn.Linear(4, 1)
    def forward(self, x):
        return self.poly(x)
# Train
epochs = 10000
batch_size = 32
model = PolynomialRegression()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
for epoch in range(epochs):
    batch_x, batch_y = get_batch_data(batch_size)
    out = model(batch_x)
    loss = criterion(out, batch_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % 100 == 0:
        print("Epoch:[{}/{}],loss:[{:.6f}]".format(epoch + 1, epochs, loss.item()))
        if epoch % 1000 == 0:
            # Every 1000 epochs, plot the current fit against the validation grid
            predict = model(features(x))
            plt.plot(x.data.numpy(), predict.squeeze(1).data.numpy(), "r")
            loss = criterion(predict, y.unsqueeze(1))
            plt.title("Loss:{:.4f}".format(loss.item()))
            plt.xlabel("X")
            plt.ylabel("Y")
            plt.scatter(x, y)
            plt.show()
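After training, the learned parameters can be read directly off the linear layer and compared with the true coefficients. A short check, not in the original post, run after the loop above:

# Compare learned parameters with the true coefficients
print("weights:", model.poly.weight.data.squeeze())  # expected: roughly [-1.13, -2.14, 3.15, -0.11]
print("bias:   ", model.poly.bias.data)              # expected: roughly [0.512]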

Training output (the loss is printed every 100 epochs):

Epoch:[1/10000],loss:[78.274490]

Epoch:[101/10000],loss:[3.716179] Epoch:[201/10000],loss:[2.097470] Epoch:[301/10000],loss:[1.362066] Epoch:[401/10000],loss:[1.377455] Epoch:[501/10000],loss:[0.753368] Epoch:[601/10000],loss:[1.218337] Epoch:[701/10000],loss:[0.592332] Epoch:[801/10000],loss:[0.615156] Epoch:[901/10000],loss:[0.674665] Epoch:[1001/10000],loss:[0.421239]

Epoch:[1101/10000],loss:[0.402231] Epoch:[1201/10000],loss:[0.392347] Epoch:[1301/10000],loss:[0.236020] Epoch:[1401/10000],loss:[0.236461] Epoch:[1501/10000],loss:[0.148626] Epoch:[1601/10000],loss:[0.210634] Epoch:[1701/10000],loss:[0.223690] Epoch:[1801/10000],loss:[0.208457] Epoch:[1901/10000],loss:[0.206550] Epoch:[2001/10000],loss:[0.177206]

Epoch:[2101/10000],loss:[0.095194] Epoch:[2201/10000],loss:[0.107039] Epoch:[2301/10000],loss:[0.082668] Epoch:[2401/10000],loss:[0.060180] Epoch:[2501/10000],loss:[0.042087] Epoch:[2601/10000],loss:[0.210911] Epoch:[2701/10000],loss:[0.027396] Epoch:[2801/10000],loss:[0.054213] Epoch:[2901/10000],loss:[0.033912] Epoch:[3001/10000],loss:[0.025427]

Epoch:[3101/10000],loss:[0.026732] Epoch:[3201/10000],loss:[0.019764] Epoch:[3301/10000],loss:[0.069222] Epoch:[3401/10000],loss:[0.020341] Epoch:[3501/10000],loss:[0.012862] Epoch:[3601/10000],loss:[0.012636] Epoch:[3701/10000],loss:[0.186847] Epoch:[3801/10000],loss:[0.007573] Epoch:[3901/10000],loss:[0.008420] Epoch:[4001/10000],loss:[0.005314]

Epoch:[4101/10000],loss:[0.006555] Epoch:[4201/10000],loss:[0.005005] Epoch:[4301/10000],loss:[0.004750] Epoch:[4401/10000],loss:[0.003588] Epoch:[4501/10000],loss:[0.002992] Epoch:[4601/10000],loss:[0.003181] Epoch:[4701/10000],loss:[0.002282] Epoch:[4801/10000],loss:[0.002377] Epoch:[4901/10000],loss:[0.001639] Epoch:[5001/10000],loss:[0.001846]

Epoch:[5101/10000],loss:[0.001734] Epoch:[5201/10000],loss:[0.000998] Epoch:[5301/10000],loss:[0.001507] Epoch:[5401/10000],loss:[0.000973] Epoch:[5501/10000],loss:[0.002113] Epoch:[5601/10000],loss:[0.000688] Epoch:[5701/10000],loss:[0.000762] Epoch:[5801/10000],loss:[0.006292] Epoch:[5901/10000],loss:[0.000619] Epoch:[6001/10000],loss:[0.000485]

Epoch:[6101/10000],loss:[0.001182] Epoch:[6201/10000],loss:[0.000671] Epoch:[6301/10000],loss:[0.000253] Epoch:[6401/10000],loss:[0.000303] Epoch:[6501/10000],loss:[0.000231] Epoch:[6601/10000],loss:[0.000134] Epoch:[6701/10000],loss:[0.000221] Epoch:[6801/10000],loss:[0.000177] Epoch:[6901/10000],loss:[0.000158] Epoch:[7001/10000],loss:[0.000199]

Epoch:[7101/10000],loss:[0.000125] Epoch:[7201/10000],loss:[0.000097] Epoch:[7301/10000],loss:[0.000056] Epoch:[7401/10000],loss:[0.000056] Epoch:[7501/10000],loss:[0.000143] Epoch:[7601/10000],loss:[0.000049] Epoch:[7701/10000],loss:[0.000042] Epoch:[7801/10000],loss:[0.000030] Epoch:[7901/10000],loss:[0.000057] Epoch:[8001/10000],loss:[0.000034]

Epoch:[8101/10000],loss:[0.000037] Epoch:[8201/10000],loss:[0.000024] Epoch:[8301/10000],loss:[0.000023] Epoch:[8401/10000],loss:[0.000026] Epoch:[8501/10000],loss:[0.000027] Epoch:[8601/10000],loss:[0.000025] Epoch:[8701/10000],loss:[0.000014] Epoch:[8801/10000],loss:[0.000015] Epoch:[8901/10000],loss:[0.000010] Epoch:[9001/10000],loss:[0.000008]

Epoch:[9101/10000],loss:[0.000006] Epoch:[9201/10000],loss:[0.000010] Epoch:[9301/10000],loss:[0.000007] Epoch:[9401/10000],loss:[0.000006] Epoch:[9501/10000],loss:[0.000008] Epoch:[9601/10000],loss:[0.000005] Epoch:[9701/10000],loss:[0.000004] Epoch:[9801/10000],loss:[0.000052] Epoch:[9901/10000],loss:[0.000097]

Results

From the training log, the loss is essentially stable after about 4000 epochs. Judging from the validation plots, however, the model starts to overfit after roughly 2000 epochs.
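Rather than eyeballing the plots, the overfitting point can be detected by tracking the loss on the fixed grid every epoch and stopping once it stops improving. A minimal early-stopping sketch, reusing the definitions above (the patience value of 500 is an arbitrary choice, not from the original post):

# Early stopping on the validation grid (illustrative sketch)
best_val, bad_epochs, patience = float("inf"), 0, 500
for epoch in range(epochs):
    batch_x, batch_y = get_batch_data(batch_size)
    loss = criterion(model(batch_x), batch_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        val_loss = criterion(model(features(x)), y.unsqueeze(1)).item()
    if val_loss < best_val:
        best_val, bad_epochs = val_loss, 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            print("Stopping at epoch {}, best val loss {:.6f}".format(epoch, best_val))
            break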
