利用梯度下降的方法对原函数进行拟合
import torch
import math
import matplotlib.pyplot as plt
class Fitting_polynomial(torch.nn.Module):
    """Cubic polynomial y = a + b*x + c*x^2 + d*x^3 with learnable coefficients.

    Each coefficient is a 0-dim ``torch.nn.Parameter`` initialised with
    ``torch.randn(())``, so ``model.parameters()`` yields exactly these four
    scalars for an optimizer to update.
    """

    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(()))
        self.b = torch.nn.Parameter(torch.randn(()))
        self.c = torch.nn.Parameter(torch.randn(()))
        self.d = torch.nn.Parameter(torch.randn(()))

    def forward(self, x):
        """Evaluate the polynomial element-wise at ``x``.

        Args:
            x: tensor of sample points. The coefficients are 0-dim, so they
               broadcast against ``x`` of any shape.

        Returns:
            Tensor with the broadcast shape of ``x``, differentiable with
            respect to the four coefficients.
        """
        y = self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3
        return y

    def string(self):
        """Return the current polynomial as a human-readable equation string."""
        return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3'

    def plot_poly(self, x, figsize=(14, 8)):
        """Open a new matplotlib figure and plot the current polynomial over ``x``.

        Args:
            x: sample points used as the horizontal axis (a tensor, as passed
               by the training loop in this notebook).
            figsize: size of the new figure in inches; the default matches
               the previously hard-coded (14, 8).
        """
        plt.figure(figsize=figsize)
        # Reuse forward() so the plotted curve cannot drift from the model,
        # and skip autograd bookkeeping since this is only for display.
        with torch.no_grad():
            y = self.forward(x)
        # matplotlib needs host memory; move tensors to CPU before plotting.
        x_plot = x.detach().cpu() if torch.is_tensor(x) else x
        plt.plot(x_plot, y.cpu().numpy(), label="fitting")
        plt.legend()
定义原函数(sin x);新建模型,并使用平方误差损失(MSELoss,reduction='sum',即误差平方和)作为损失函数
# Target data: 1000 evenly spaced sample points on [-pi, pi] and sin(x) there.
x = torch.linspace(start=-math.pi, end=math.pi, steps=1000)
y = x.sin()
# Cubic-polynomial model whose four coefficients start out random.
model = Fitting_polynomial()
# reduction='sum' makes the loss the total (not mean) squared error, which is
# why the learning rate below is so small.
criterion = torch.nn.MSELoss(reduction="sum")
optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-8, momentum=0.9)
训练
# Full-batch gradient descent: each iteration evaluates every sample point.
num_iters = 30000
log_every = 2000
for t in range(num_iters):
    # Forward pass: evaluate the cubic at every sample point.
    y_pred = model(x)
    # The criterion was built with reduction='sum', so this is the summed
    # squared error; the "mse" label in the log is kept for output compatibility.
    loss = criterion(y_pred, y)
    if (t + 1) % log_every == 0:
        print("epoch:{},mse:{}".format(t+1, loss.item()))
        print(f'Result: {model.string()}')
        # plot_poly opens a fresh figure, so draw the fit first and then overlay
        # the target curve on that same figure. Drawing "raw" first put it on
        # the previous snapshot's figure and left the last fit without it.
        model.plot_poly(x)
        plt.plot(x, y, label="raw")
        plt.legend()
    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
epoch:2000,mse:1447.431396484375
Result: y = 0.11133185774087906 + -0.7943404316902161 x + -0.01912527345120907 x^2 + 0.14130832254886627 x^3
epoch:4000,mse:958.48291015625
Result: y = 0.09364601969718933 + -0.4856458902359009 x + -0.016138650476932526 x^2 + 0.09744332730770111 x^3
epoch:6000,mse:635.2223510742188
Result: y = 0.07877693325281143 + -0.23467518389225006 x + -0.013576172292232513 x^2 + 0.06178080663084984 x^3
epoch:8000,mse:421.49969482421875
Result: y = 0.06626877188682556 + -0.03063386306166649 x + -0.01142055168747902 x^2 + 0.03278686851263046 x^3
epoch:10000,mse:280.1954040527344
Result: y = 0.05574660003185272 + 0.13525323569774628 x + -0.009607195854187012 x^2 + 0.009214584715664387 x^3
epoch:12000,mse:186.76930236816406
Result: y = 0.046895161271095276 + 0.27012065052986145 x + -0.008081765845417976 x^2 + -0.009949849918484688 x^3
epoch:14000,mse:124.99755096435547
Result: y = 0.03944912552833557 + 0.37976884841918945 x + -0.0067985402420163155 x^2 + -0.025530679151415825 x^3
epoch:16000,mse:84.1541748046875
Result: y = 0.03318541869521141 + 0.4689137637615204 x + -0.005719069391489029 x^2 + -0.03819802775979042 x^3
epoch:18000,mse:57.14808654785156
Result: y = 0.027916258201003075 + 0.5413891077041626 x + -0.004811000544577837 x^2 + -0.04849664866924286 x^3
epoch:20000,mse:39.29086685180664
Result: y = 0.023483725264668465 + 0.6003121137619019 x + -0.004047111142426729 x^2 + -0.056869521737098694 x^3
epoch:22000,mse:27.48279571533203
Result: y = 0.019754987210035324 + 0.6482172012329102 x + -0.0034045118372887373 x^2 + -0.06367673724889755 x^3
epoch:24000,mse:19.674556732177734
Result: y = 0.016618283465504646 + 0.687164306640625 x + -0.002863941714167595 x^2 + -0.06921107321977615 x^3
epoch:26000,mse:14.511096000671387
Result: y = 0.013979638926684856 + 0.718828558921814 x + -0.0024092060048133135 x^2 + -0.07371050119400024 x^3
epoch:28000,mse:11.096450805664062
Result: y = 0.011759957298636436 + 0.7445719838142395 x + -0.0020266727078706026 x^2 + -0.07736856490373611 x^3
epoch:30000,mse:8.838287353515625
Result: y = 0.009892717935144901 + 0.7655012011528015 x + -0.0017048786394298077 x^2 + -0.0803426131606102 x^3
拟合过程可视化(部分):