梯度下降法在凸优化中应用非常广泛,常用于求凸函数极值。梯度是个向量,其形式为
通常是表示函数上升最快的方向!因此,我们只需要每一步往梯度方向走一小步,最终就可以到达极值点,其表现形式为:
初始点为x0, 然后往梯度的反方向移动一小步r到x1, 再次往梯度反方向移动r到x2,... ...,最终会越来越接近极值点min的。
迭代时的公式为X(n+1) = X(n) - r * grad(f)
下面举例子说明梯度下降法求极值点的有效性:
#!/usr/bin/python
# -*- coding:utf8 -*-
import random
import numpy as np
import math
#f(x,y) = (x-2)^2+(y-1)^2 + 1
def solution(grad_func) :
rate = 0.1
x = random.uniform(-10,10)
y = random.uniform(-10,10)
point = np.array([x, y])
for index in xrange(0, 1000) :
grad = grad_func(point[0], point[1])
point = point - rate * grad
print grad
if reduce(lambda a,b: math.sqrt(a*a+b*b), [grad[i] for i in xrange(grad.shape[0])]) < 0.000001 : break
print "times of iterate : %s" % index
return point[0], point[1]
if __name__ == "__main__" :
x, y = solution(lambda a,b: np.array([2*(a-2), 2*(b-1)]))
print "minimum point of f(x,y) = (x-2)^2+(y-1)^2 + 1 : (%s,%s)" % (x, y)
f:\python_workspace\SGD>python gd.py
[ 5.68071667 -21.54721046]
[ 4.54457333 -17.23776836]
[ 3.63565867 -13.79021469]
[ 2.90852693 -11.03217175]
[ 2.32682155 -8.8257374 ]
[ 1.86145724 -7.06058992]
[ 1.48916579 -5.64847194]
[ 1.19133263 -4.51877755]
[ 0.95306611 -3.61502204]
[ 0.76245288 -2.89201763]
[ 0.60996231 -2.31361411]
[ 0.48796985 -1.85089128]
[ 0.39037588 -1.48071303]
[ 0.3123007 -1.18457042]
[ 0.24984056 -0.94765634]
[ 0.19987245 -0.75812507]
[ 0.15989796 -0.60650006]
[ 0.12791837 -0.48520004]
[ 0.10233469 -0.38816004]
[ 0.08186776 -0.31052803]
[ 0.0654942 -0.24842242]
[ 0.05239536 -0.19873794]
[ 0.04191629 -0.15899035]
[ 0.03353303 -0.12719228]
[ 0.02682643 -0.10175382]
[ 0.02146114 -0.08140306]
[ 0.01716891 -0.06512245]
[ 0.01373513 -0.05209796]
[ 0.0109881 -0.04167837]
[ 0.00879048 -0.03334269]
[ 0.00703239 -0.02667415]
[ 0.00562591 -0.02133932]
[ 0.00450073 -0.01707146]
[ 0.00360058 -0.01365717]
[ 0.00288047 -0.01092573]
[ 0.00230437 -0.00874059]
[ 0.0018435 -0.00699247]
[ 0.0014748 -0.00559398]
[ 0.00117984 -0.00447518]
[ 0.00094387 -0.00358014]
[ 0.0007551 -0.00286412]
[ 0.00060408 -0.00229129]
[ 0.00048326 -0.00183303]
[ 0.00038661 -0.00146643]
[ 0.00030929 -0.00117314]
[ 0.00024743 -0.00093851]
[ 0.00019794 -0.00075081]
[ 0.00015836 -0.00060065]
[ 0.00012668 -0.00048052]
[ 0.00010135 -0.00038442]
[ 8.10778975e-05 -3.07532064e-04]
[ 6.48623180e-05 -2.46025651e-04]
[ 5.18898544e-05 -1.96820521e-04]
[ 4.15118835e-05 -1.57456417e-04]
[ 3.32095068e-05 -1.25965133e-04]
[ 2.65676055e-05 -1.00772107e-04]
[ 2.12540844e-05 -8.06176854e-05]
[ 1.70032675e-05 -6.44941483e-05]
[ 1.36026140e-05 -5.15953187e-05]
[ 1.08820912e-05 -4.12762549e-05]
[ 8.70567296e-06 -3.30210039e-05]
[ 6.96453837e-06 -2.64168032e-05]
[ 5.57163069e-06 -2.11334425e-05]
[ 4.45730455e-06 -1.69067540e-05]
[ 3.56584364e-06 -1.35254032e-05]
[ 2.85267491e-06 -1.08203226e-05]
[ 2.28213993e-06 -8.65625806e-06]
[ 1.82571195e-06 -6.92500645e-06]
[ 1.46056956e-06 -5.54000516e-06]
[ 1.16845565e-06 -4.43200413e-06]
[ 9.34764516e-07 -3.54560330e-06]
[ 7.47811614e-07 -2.83648264e-06]
[ 5.98249291e-07 -2.26918611e-06]
[ 4.78599433e-07 -1.81534889e-06]
[ 3.82879547e-07 -1.45227911e-06]
[ 3.06303638e-07 -1.16182329e-06]
[ 2.45042910e-07 -9.29458632e-07]
times of iterate : 76
minimum point of f(x,y) = (x-2)^2+(y-1)^2 + 1 : (2.00000009802,0.999999628217)
可以看到梯度长度慢慢地变小,最终的解与实际解(2, 1)也非常接近!在实际应用中,我们需要对步长r值进行调整,如果步长r太小,那么需要迭代很多次才能收敛,而如果太大,则会越过极值点,一直在极值点附近徘徊。针对这种情况,可以让步长r随着迭代次数的增加而变小。
附上tensorflow计算,代码更简洁,迭代优化过程完全交给tensorflow。
#!/usr/bin/env python
#-*- coding:utf8 -*-
import sys
import tensorflow as tf
import numpy as np
reload(sys)
sys.setdefaultencoding('utf8')
x = tf.Variable(tf.random_uniform([1], 0, 1.0))
y = tf.Variable(tf.random_uniform([1], 0, 1.0))
# f(x,y) = (x-2)^2+(y-1)^2 + 1
fxy = tf.pow((x - 2), 2) + tf.pow((y - 1), 2) + 1
optimizer = tf.train.GradientDescentOptimizer(0.02)
train = optimizer.minimize(fxy)
# 初始化变量
init = tf.initialize_all_variables()
# 启动图 (graph)
sess = tf.Session()
sess.run(init)
for step in xrange(0, 201):
sess.run(train)
if step % 20 == 0:
print step, sess.run(x), sess.run(y)
0 [0.84303737] [0.29856724]
20 [1.4886197] [0.68996495]
40 [1.7739687] [0.8629638]
60 [1.9000934] [0.93942964]
80 [1.9558412] [0.9732277]
100 [1.9804816] [0.9881665]
120 [1.9913727] [0.9947696]
140 [1.9961867] [0.99768823]
160 [1.9983146] [0.99897814]
180 [1.9992551] [0.9995484]
200 [1.9996707] [0.9998003]