pytorch入门(八):梯度下降

3 篇文章 0 订阅
2 篇文章 0 订阅

1、什么是迭代法

解方程时,有时可以得到精确解,有时只能得到近似解。如解方程, x 2 − 4 = 0 x^2-4=0 x24=0可以得到精确解,
x 1 = 2 , x 2 = − 2 x_1=2,x_2=-2 x1=2x2=2如果想知道 2 \sqrt{2} 2 是多少,解方程, x 2 − 2 = 0 x^2-2=0 x22=0就只能求近似解了。迭代法是一种求近似解的方法。它依据迭代公式逐步求解,得到一个收敛于 2 \sqrt{2} 2 的点列 { x n } \{x_n\} {xn},其中的 x n x_n xn就是精确度越来越高的近似解。
首先把方程变形为 x 2 − 1 = 1 x^2-1=1 x21=1于是有 ( x − 1 ) ( x + 1 ) = 1 (x-1)(x+1)=1 (x1)(x+1)=1这就就是 x = 1 + 1 x + 1 x=1+\frac{1}{x+1} x=1+x+11这是一个迭代公式。有了迭代公式,把 x = 1 x=1 x=1 代入等号右边,得到 x = 1.5 x=1.5 x=1.5 ,再把 1.5 1.5 1.5 代入到等号右边,又得到 x = 1.4 x=1.4 x=1.4 ,如此往复,不断执行迭代操作,所得的 x x x 越来越接近于 2 \sqrt{2} 2 ,迭代足够多的次数就到想要的近似值了。

import numpy as np

x = 1
for k in range(10):
    x = 1 + 1/(x+1)
    print('k =', k, ', x =', x)

运行结果

k = 0 , x = 1.5
k = 1 , x = 1.4
k = 2 , x = 1.4166666666666667
k = 3 , x = 1.4137931034482758
k = 4 , x = 1.4142857142857144
k = 5 , x = 1.4142011834319526
k = 6 , x = 1.4142156862745099
k = 7 , x = 1.4142131979695431
k = 8 , x = 1.4142136248948696
k = 9 , x = 1.4142135516460548

增加迭代次数,还能得到更为精确的结果。请思考还能写出别的迭代公式吗?如 x = 2 − 2 x + 2 x=2-\frac{2}{x+2} x=2x+22。实际上,可以写出无数个类似地迭代公式。

2、什么是梯度下降法

想象一下自己如何下山?先看准一个方向,沿着这个方向走一小步,再看准一个方向,再沿着它走一小步,一步接一步,就到山脚了。

梯度下降类似下山的方法用来求函数 f ( x ) f(x) f(x) 的最小值,当然也可以求最大值。它是一种迭代的方法,迭代公式为
x n = x n − 1 − α ∇ f ( x n − 1 ) x_n=x_{n-1}-\alpha \nabla f(x_{n-1}) xn=xn1αf(xn1)其中, ∇ f ( x n − 1 ) \nabla f(x_{n-1}) f(xn1) 是梯度,表示一个方向,$ α \alpha α 是步长。公式表示将点 x n − 1 x_{n-1} xn1沿着梯度方向移动 α \alpha α步长后,就能到达下一个点 x n x_n xn

这个鬼东西是哪来的?回忆下泰勒公式,
f ( x ) ≈ f ( x 0 ) + ∇ f ( x 0 ) Δ x f(x)\approx f(x_0)+\nabla f(x_0) \Delta x f(x)f(x0)+f(x0)Δx x 0 x_0 x0是一个点,从 x 0 x_0 x0跳一步 Δ x \Delta x Δx,就可跳到下一个点 x x x。沿哪个方向?如果选择 ∇ f ( x 0 ) \nabla f(x_0) f(x0) 的反方向 − ∇ f ( x 0 ) -\nabla f(x_0) f(x0) 的话,公式的第二项就成为一个地地道道的负数,它会使得 f ( x ) < f ( x 0 ) f(x)<f(x_0) f(x)<f(x0),不是吗?

下面用梯度下降法解方程 x 3 − 2 x − 3 = 0 x^3-2x-3=0 x32x3=0不是说梯度下降法用来求最小值,它怎么能解方程?现在不是求 g ( x ) = x 3 − 2 x − 3 g(x)=x^3-2x-3 g(x)=x32x3的最小值,那会小于0的。但是我们可以求 f ( x ) = ( x 3 − 2 x − 3 ) 2 f(x)=(x^3-2x-3)^2 f(x)=(x32x3)2 的最小值。如果方程有解的话,最小值就是零,相应的最小值点不就是方程的解吗?

def g(x):
    return x**3-2*x-3

def f(x):
    return g(x)**2

def df(x):
    return 2*(x**3-2*x-3)*(6*x-2)

alpha = 0.001
x = 3
y = f(x)
for k in range(100):
    x = x - alpha * df(x)
    y_new = f(x)
    if k % 10 == 0:
        print('k={}  x={:.5e}  y={:.5e}'.format(k, x, f(x)))
    if abs(y - y_new) < 1e-12 :
        break
    y = y_new

print('g(x) = {:.4e}'.format(g(x)))

先看看结果再做解释

k=0  x=2.42400e+00  y=4.08945e+01
k=10  x=1.94162e+00  y=1.90508e-01
k=20  x=1.90088e+00  y=4.45978e-03
k=30  x=1.89454e+00  y=1.21013e-04
k=40  x=1.89350e+00  y=3.36037e-06
k=50  x=1.89332e+00  y=9.36701e-08
k=60  x=1.89330e+00  y=2.61271e-09
k=70  x=1.89329e+00  y=7.28830e-11
k=80  x=1.89329e+00  y=2.03315e-12
g(x) = 1.4259e-06

f ( x ) f(x) f(x)是目标函数,我们要做的是,看看 x x x等于几时 f ( x ) f(x) f(x)取最小值。如果 f ( x ) f(x) f(x)的最小值等于0,这个几就是方程的解。

d f ( x ) df(x) df(x) f ( x ) f(x) f(x)的导数,可用链式法则算出来。

a l p h a alpha alpha是步长。试试改为更大的值或更小的值,看看有什么事情发生。 x = 3 x=3 x=3 指定了迭代的出发点。这个出发点不能距离真实解太远,否则会出问题。试试把它改为30看看出了什么情况。这好比是说在校门口往西30米处等我,与此同时,对方却在另一座城市,这么做肯定会出问题。

迭代公式为 x = x − a l p h a ∗ d f ( x ) x = x - alpha * df(x) x=xalphadf(x),他要被执行100次(循环100次),每隔10次打印一下中间结果,以免等得心烦。

y y y y _ n e w y\_new y_new 是连续两次计算的函数值。如果两次计算的结果差别不大,再算下去就也就如此了,及时止损。

最后显示 g ( x ) g(x) g(x)的值,它几乎就是零。所以, x = 1.89329 e + 00 x=1.89329e+00 x=1.89329e+00就是近似解。

3、 arg min ⁡ x f ( x ) = 2 x 1 2 + x 2 2 \argmin_{x}f(x)=2x_1^2+x_2^2 xargminf(x)=2x12+x22

目标函数为二元函数, d f df df定义了梯度函数。因为是二元函数,初始点也有两个分量,x = np.array([5, 5])。

xs 和 ys 是两个列表。它们存储了各个中间点,以便最后作图。contour 用于画等高线图。

import matplotlib.pyplot as plt

def f(x):
    return 2*x[0]**2 + x[1]**2

def df(x):
    return np.array((4*x[0], 2*x[1]))

alpha = 0.1
x = np.array([5, 5])
k = 0
xs = []
ys = []
while k < 50:
    k += 1
    x = x - alpha*df(x)
    z = f(x)
    if k % 10 == 0:
        print('k={:2d}, (x,y)=({:.4f},{:.4f}),z={:.6f}'.
              format(k, x[0],x[1], z))
    xs.append(x[0])
    ys.append(x[1])

plt.plot(xs,ys,'r>-')

x = np.linspace(-5,5,20)
y = np.linspace(-5,5,20)
x, y = np.meshgrid(x,y)
z = 2*x**2 + y**2
plt.contour(x, y, z, 
            [2,5,9,14,20,27],
            colors=['#000000']*6
            )

结果如下

k=10, (x,y)=(0.0302,0.5369), z=0.290058
k=20, (x,y)=(0.0002,0.0576), z=0.003323
k=30, (x,y)=(0.0000,0.0062), z=0.000038
k=40, (x,y)=(0.0000,0.0007), z=0.000000
k=50, (x,y)=(0.0000,0.0001), z=0.000000

在这里插入图片描述
跨越了梯度下降,接下来就能搞定线性回归了,下次再见。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值