Fitting a Heart Curve with PyTorch

Using PyTorch, this post fits the y component of a heart curve, f(t) = 13cos t - 5cos 2t - 2cos 3t - cos 4t, with a single linear layer. The model trains for 100,000 epochs on randomly drawn batches while the fit is checked against the fixed curve, and the loss is tracked throughout. The final model fits the target function well, though it might reach an even better state with more epochs.

Preface (Short on Valentine's dates? Let AI fill in!)

Valentine's Day came around and I still had nobody to spend it with, so I dragged an AI into celebrating with me.

At first I trained for 10,000 epochs, and the result was so far off that the AI might as well have told me "fat chance."

Only after I tacked a zero onto the epoch count did it behave.

Enough chatter; on to the main content.

Main Idea

The target function is f(t) = 13cos t - 5cos 2t - 2cos 3t - cos 4t (a trigonometric polynomial: it is linear in the terms cos kt).

Unlike the previous post, this time the target is a parametric curve, but we only fit the y component.

The input features are [cos t, cos 2t, cos 3t, cos 4t].

The weights to fit are [13, -5, -2, -1].

So no activation is needed; a single linear layer is enough (a least-squares sanity check follows at the end of this section).

For validation, training batches are drawn at random while the fit is checked against the fixed curve points (a held-out check rather than true leave-one-out).

Training runs for 100,000 epochs in total; the fit converges slowly, to put it mildly.
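Since the model is linear in its features, the weights can also be recovered in closed form. The following least-squares sanity check is not part of the original post; it merely confirms that [13, -5, -2, -1] is the solution the linear layer should converge to:

import torch

t = torch.linspace(-15, 15, 1000)
# design matrix: one row per sample, columns [cos t, cos 2t, cos 3t, cos 4t]
A = torch.stack([torch.cos(k * t) for k in range(1, 5)], dim=1)
y = A @ torch.tensor([13.0, -5.0, -2.0, -1.0])
# solve the least-squares problem A @ w = y
w = torch.linalg.lstsq(A, y.unsqueeze(1)).solution
print(w.squeeze())  # prints values very close to [13., -5., -2., -1.]

That torch.linalg.lstsq solves this in one shot is exactly why the single linear layer needs no activation.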

Full Code

# Heart curve: y(t) = 13cos t - 5cos 2t - 2cos 3t - cos 4t
import math
import torch
import numpy as np
import matplotlib.pyplot as plt

t = torch.linspace(-15, 15, 1000)  # 1000 evenly spaced points from -15 to 15
x = 16 * torch.sin(t) ** 3
y = 13 * torch.cos(t) - 5 * torch.cos(2 * t) - 2 * torch.cos(3 * t) - torch.cos(4 * t)
plt.scatter(x.numpy(), y.numpy())  # preview the target heart shape
plt.show()


def y_features(t):
    # feature map: [cos t, cos 2t, cos 3t, cos 4t], one row per sample
    t = t.unsqueeze(1)
    return torch.cat([torch.cos(t * i) for i in range(1, 5)], 1)

def x_features(t):
    # x coordinate of the heart curve
    t = t.unsqueeze(1)
    return 16 * torch.sin(t) ** 3

# the true weights the linear layer should learn
t_weights = torch.tensor([13.0, -5.0, -2.0, -1.0]).unsqueeze(1)

def target(t):
    return t.mm(t_weights)  # matrix multiply: features (N, 4) x weights (4, 1)
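A quick, illustrative check of these helpers (not in the original post): at t = 0 every cosine equals 1, so the target is 13 - 5 - 2 - 1 = 5.

tt = torch.tensor([0.0, math.pi / 2])
print(y_features(tt))          # shape (2, 4), first row is [1, 1, 1, 1]
print(target(y_features(tt)))  # shape (2, 1), first entry is 5.0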
# randomly generate a batch of training data
def get_batch_data(batch_size):
    batch_x = torch.randn(batch_size)  # parameter t sampled from a standard normal
    features_x = x_features(batch_x)   # x coordinates (returned for plotting only)
    features_y = y_features(batch_x)   # model inputs
    target_y = target(features_y)      # true y values
    return features_x, features_y, target_y
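Note that torch.randn concentrates t near 0, which still happens to cover roughly one period of the cosines. A hypothetical alternative, not in the original code, is to sample t uniformly over a full period:

# alternative sampler (illustrative): t drawn uniformly from [0, 2*pi]
def get_batch_data_uniform(batch_size):
    batch_x = torch.rand(batch_size) * 2 * math.pi
    features_y = y_features(batch_x)
    return x_features(batch_x), features_y, target(features_y)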
# define the model: a single linear layer mapping 4 features to 1 output
class PolynomiaRegression(torch.nn.Module):
    def __init__(self):
        super(PolynomiaRegression, self).__init__()
        self.poly = torch.nn.Linear(4, 1)

    def forward(self, t):
        return self.poly(t)
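After training, the layer's weights should approach [13, -5, -2, -1] with a bias near 0. A small hypothetical helper (not in the original post) to inspect them:

# print the learned weights and bias (hypothetical helper)
def show_coeffs(model):
    w = model.poly.weight.detach().squeeze().tolist()
    b = model.poly.bias.item()
    print("weights:", [round(v, 3) for v in w], "bias:", round(b, 3))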
# training loop
epochs = 100000
batch_size = 32
model = PolynomiaRegression()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
loss_value = np.inf  # best batch loss seen so far
loss_holder = []
step = 0
for epoch in range(epochs):
    target_x, batch_x, batch_y = get_batch_data(batch_size)
    out = model(batch_x)
    loss = criterion(out, batch_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if loss.item() < loss_value:
        torch.save(model, 'model.ckpt')  # checkpoint whenever the batch loss improves
        loss_value = loss.item()
    if epoch % 100 == 0:
        step += 1
        # note: MSELoss already averages over the batch, so this is not a true RMSE
        loss_holder.append([step, math.sqrt(loss.item() / batch_size)])
        if epoch % 1000 == 0:
            print("Epoch:[{}/{}],loss:[{:.6f}]".format(epoch + 1, epochs, loss.item()))
            if epoch % 10000 == 0:
                # evaluate the current fit on the full fixed curve and plot it
                with torch.no_grad():
                    predict = model(y_features(t))
                    eval_loss = criterion(predict, y.unsqueeze(1))
                plt.plot(x.numpy(), predict.squeeze(1).numpy(), "r")
                plt.title("Loss:{:.4f}".format(eval_loss.item()))
                plt.xlabel("X")
                plt.ylabel("Y")
                plt.scatter(x.numpy(), y.numpy())
                plt.show()
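Since torch.save(model, ...) stores the whole module, reloading requires the class definition to be in scope. A minimal sketch (on recent PyTorch versions, whole-module loads may additionally need weights_only=False):

# reload the best checkpoint saved during training
best = torch.load('model.ckpt')  # or torch.load('model.ckpt', weights_only=False)
best.eval()
with torch.no_grad():
    predict = best(y_features(t))

The training log printed every 1,000 epochs: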

Epoch:[1/100000],loss:[65.298500]
Epoch:[1001/100000],loss:[2.491852]
Epoch:[2001/100000],loss:[2.300167]
Epoch:[3001/100000],loss:[0.070248]
Epoch:[4001/100000],loss:[13.497649]
Epoch:[5001/100000],loss:[0.566398]
Epoch:[6001/100000],loss:[3.217979]
Epoch:[7001/100000],loss:[0.192295]
Epoch:[8001/100000],loss:[2.993185]
Epoch:[9001/100000],loss:[0.840189]
Epoch:[10001/100000],loss:[0.044125]
Epoch:[11001/100000],loss:[0.030104]
Epoch:[12001/100000],loss:[0.057011]
Epoch:[13001/100000],loss:[0.652313]
Epoch:[14001/100000],loss:[0.055101]
Epoch:[15001/100000],loss:[0.035360]
Epoch:[16001/100000],loss:[0.404017]
Epoch:[17001/100000],loss:[0.032413]
Epoch:[18001/100000],loss:[0.026495]
Epoch:[19001/100000],loss:[0.258574]
Epoch:[20001/100000],loss:[0.013485]
Epoch:[21001/100000],loss:[0.030264]
Epoch:[22001/100000],loss:[1.169118]
Epoch:[23001/100000],loss:[1.099216]
Epoch:[24001/100000],loss:[0.470649]
Epoch:[25001/100000],loss:[0.115718]
Epoch:[26001/100000],loss:[0.105106]
Epoch:[27001/100000],loss:[0.365399]
Epoch:[28001/100000],loss:[0.002508]
Epoch:[29001/100000],loss:[0.008796]
Epoch:[30001/100000],loss:[0.333936]
Epoch:[31001/100000],loss:[0.007649]
Epoch:[32001/100000],loss:[0.012496]
Epoch:[33001/100000],loss:[0.006446]
Epoch:[34001/100000],loss:[0.060192]
Epoch:[35001/100000],loss:[0.005069]
Epoch:[36001/100000],loss:[0.003233]
Epoch:[37001/100000],loss:[0.003273]
Epoch:[38001/100000],loss:[0.001058]
Epoch:[39001/100000],loss:[0.001612]
Epoch:[40001/100000],loss:[0.037987]
Epoch:[41001/100000],loss:[0.081242]
Epoch:[42001/100000],loss:[0.008152]
Epoch:[43001/100000],loss:[0.002064]
Epoch:[44001/100000],loss:[0.001179]
Epoch:[45001/100000],loss:[0.001132]
Epoch:[46001/100000],loss:[0.003256]
Epoch:[47001/100000],loss:[0.001605]
Epoch:[48001/100000],loss:[0.021894]
Epoch:[49001/100000],loss:[0.017690]
Epoch:[50001/100000],loss:[0.001353]
Epoch:[51001/100000],loss:[0.000576]
Epoch:[52001/100000],loss:[0.000605]
Epoch:[53001/100000],loss:[0.000519]
Epoch:[54001/100000],loss:[0.002327]
Epoch:[55001/100000],loss:[0.000258]
Epoch:[56001/100000],loss:[0.000120]
Epoch:[57001/100000],loss:[0.013316]
Epoch:[58001/100000],loss:[0.000190]
Epoch:[59001/100000],loss:[0.000207]
Epoch:[60001/100000],loss:[0.000129]
Epoch:[61001/100000],loss:[0.000184]
Epoch:[62001/100000],loss:[0.001675]
Epoch:[63001/100000],loss:[0.000082]
Epoch:[64001/100000],loss:[0.002057]
Epoch:[65001/100000],loss:[0.002364]
Epoch:[66001/100000],loss:[0.000843]
Epoch:[67001/100000],loss:[0.000024]
Epoch:[68001/100000],loss:[0.000076]
Epoch:[69001/100000],loss:[0.000077]
Epoch:[70001/100000],loss:[0.000027]
Epoch:[71001/100000],loss:[0.000503]
Epoch:[72001/100000],loss:[0.001702]
Epoch:[73001/100000],loss:[0.001267]
Epoch:[74001/100000],loss:[0.000023]
Epoch:[75001/100000],loss:[0.000828]
Epoch:[76001/100000],loss:[0.000147]
Epoch:[77001/100000],loss:[0.000082]
Epoch:[78001/100000],loss:[0.000008]
Epoch:[79001/100000],loss:[0.000137]
Epoch:[80001/100000],loss:[0.000009]
Epoch:[81001/100000],loss:[0.000667]
Epoch:[82001/100000],loss:[0.000015]
Epoch:[83001/100000],loss:[0.000071]
Epoch:[84001/100000],loss:[0.000247]
Epoch:[85001/100000],loss:[0.000227]
Epoch:[86001/100000],loss:[0.000258]
Epoch:[87001/100000],loss:[0.000006]
Epoch:[88001/100000],loss:[0.000004]
Epoch:[89001/100000],loss:[0.000016]
Epoch:[90001/100000],loss:[0.000054]
Epoch:[91001/100000],loss:[0.000018]
Epoch:[92001/100000],loss:[0.000002]
Epoch:[93001/100000],loss:[0.000001]
Epoch:[94001/100000],loss:[0.000003]
Epoch:[95001/100000],loss:[0.000048]
Epoch:[96001/100000],loss:[0.000061]
Epoch:[97001/100000],loss:[0.000076]
Epoch:[98001/100000],loss:[0.000001]
Epoch:[99001/100000],loss:[0.000015]

Results

From the log, the loss is essentially stable after around epoch 98,100. Judging by the evaluation plots, though, the model still hadn't overfit even at 100,000 epochs... so it could apparently stand a few more rounds of training.
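The loss_holder values collected every 100 epochs are never actually plotted in the code above; a minimal sketch to visualize the training curve (assuming training has just finished in the same session):

# plot the values recorded every 100 epochs during training
steps, losses = zip(*loss_holder)
plt.plot(steps, losses)
plt.xlabel("step (1 step = 100 epochs)")
plt.ylabel("recorded loss value")
plt.show()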
