我们使用奇次幂多项式函数
$$y = ax + bx^3 + cx^5 + dx^7$$
来拟合三角函数 sin 函数
$$y = \sin(x)$$
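由泰勒公式也可知，sin 可以展开为只含奇次幂的多项式：
$$\sin(x) = x - \frac{x^3}{3!} + \frac{x^5}{5!} - \frac{x^7}{7!} + \cdots$$
因此拟合得到的系数可以与 $1/1!,\ 1/3!,\ 1/5!,\ 1/7!$（符号正负交替）进行对照。
代码如下：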
# The snippet relies on both `math` (for math.pi) and `torch`; import them so
# it can be run as-is instead of failing with NameError.
import math

import torch

# Build the training data and the trainable coefficients on the CPU in
# single precision.
dtype = torch.float
device = torch.device("cpu")

# 2000 evenly spaced sample points over [-pi, pi] and the target sin(x).
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# One scalar (0-dim) coefficient per term of the model; they multiply
# x, x^3, x^5, x^7 respectively in the training loop. Random initialisation;
# requires_grad=True makes autograd populate `.grad` on backward().
a = torch.randn((), device=device, dtype=dtype, requires_grad=True)
b = torch.randn((), device=device, dtype=dtype, requires_grad=True)
c = torch.randn((), device=device, dtype=dtype, requires_grad=True)
d = torch.randn((), device=device, dtype=dtype, requires_grad=True)

# Step size for the manual gradient-descent update.
learning_rate = 1e-6
# Plain gradient descent on the mean-squared error between the odd-power
# polynomial and sin(x). Each step: forward pass, loss, backprop, then a
# manual in-place update done under no_grad so the update is not traced.
coeffs = (a, b, c, d)
for step in range(4000):
    # a, b, c, d weight x, x^3, x^5, x^7 respectively.
    y_pred = a * x + b * (x ** 3) + c * (x ** 5) + d * (x ** 7)
    loss = (y_pred - y).pow(2).mean()

    # Report the loss once every 100 steps (at steps 99, 199, ...).
    if (step + 1) % 100 == 0:
        print(step, loss.item())

    loss.backward()

    with torch.no_grad():
        for coeff in coeffs:
            # In-place subtraction keeps the same leaf tensor object.
            coeff -= learning_rate * coeff.grad
            # Reset the gradient so the next backward() starts fresh
            # (PyTorch accumulates into .grad otherwise).
            coeff.grad = None
import math

# Fitted coefficients, in the order of the powers x, x^3, x^5, x^7.
print(f'Result: y = {a.item()} x + {b.item()} x^3 + {c.item()} x^5 + {d.item()} x^7')

# Magnitudes of the matching Taylor coefficients of sin(x)
# (sin x = x - x^3/3! + x^5/5! - x^7/7! + ..., signs alternate), printed as a
# reference point for the fit. A least-squares fit over [-pi, pi] need not
# reproduce them exactly.
for k in (1, 3, 5, 7):
    print(f'1/{k}! = {1. / math.factorial(k)}')