该算法是为了解决L1正则化下线性回归无法使用梯度下降法求解的问题
先直观的来了解一下坐标轴下降算法
给定二元函数
f
(
x
,
y
)
=
5
x
2
−
6
x
y
+
5
y
2
f(x,y)=5x^2-6xy+5y^2
f(x,y)=5x2−6xy+5y2
如何求解该函数的最小值?(虽然很容易就能看出是0), 坐标轴下降算法可以解决这个问题
- 先来看这个函数的图像
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
x = np.linspace(-1.5,1.5,100)
y = np.linspace(-1.5,1.5,100)
X, Y = np.meshgrid(x, y)
f = 5 * pow(X, 2) - 6 * X * Y + 5 * pow(Y,2)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, f, rstride=1, cstride=1, cmap='rainbow')
plt.show()
很明显这是个凸函数, 函数的最小值应该位于图像最中心的低谷处.
- 再来看看等高线图
x = np.linspace(-1.5, 1.5, 100)
y = np.linspace(-1.5, 1.5, 100)
X, Y = np.meshgrid(x, y)
f = 5 * pow(X, 2) - 6 * X * Y + 5 * pow(Y, 2)
plt.figure(figsize=(10, 12))
figure = plt.contour(X, Y, f, [0.5, 1.4, 2.3, 3.2, 4.1, 5], colors='k')
plt.clabel(figure, fontsize=10, colors="k")
plt.vlines(0, min(y), max(y))
plt.hlines(0, min(x), max(x))
plt.xlabel("x")
plt.ylabel("y")
plt.show()
这样就很明显能看到函数的最小值在哪里了.
- 坐标轴下降算法
def dy(x):
return 6/10*x
def dx(y):
return 6/10*y
min_x = -0.5 #初值
arrowY = [-1.0] #初值
arrowX = [-0.5] #初值
for i in range(1,10):
if i % 2 == 1:
min_y = dy(min_x)
print(min_y)
arrowY.append(min_y)
elif i % 2 == 0:
min_x = dx(min_y)
arrowX.append(min_x)
print(min_x)
x = np.linspace(-1.5, 1.5, 100)
y = np.linspace(-1.5, 1.5, 100)
X, Y = np.meshgrid(x, y)
f = 5 * pow(X, 2) - 6 * X * Y + 5 * pow(Y, 2)
plt.figure(figsize=(10, 12))
figure = plt.contour(X, Y, f, [0.5, 1.4, 2.3, 3.2, 4.1, 5], colors='k')
plt.clabel(figure, fontsize=10, colors="k")
plt.annotate(
'min',
xy=(0, 0),
xytext=(0.15, 0.15),
arrowprops=dict(facecolor='red', shrink=0.01),
)
for i in range(4):
plt.arrow(arrowX[i],
arrowY[i],
0,
arrowY[i+1]-arrowY[i],
width=0.0005,
length_includes_head=True,
ec='r',
head_width=0.02,
fc='r')
for i in range(4):
plt.arrow(arrowX[i],
arrowY[i+1],
arrowX[i+1]-arrowX[i],
0,
width=0.0005,
length_includes_head=True,
ec='r',
head_width=0.02,
fc='r')
plt.vlines(0, min(y), max(y))
plt.hlines(0, min(x), max(x))
plt.xlabel("x")
plt.ylabel("y")
plt.show()
假如我们给定初值
(
−
0.5
,
−
1.0
)
(-0.5,-1.0)
(−0.5,−1.0)然后要求只按照坐标轴(要同时包括
x
、
y
x、y
x、y轴)的方向来下降(迭代)这个数值. 下一次迭代的结果将是
(
−
0.18
,
−
0.3
)
(-0.18,-0.3)
(−0.18,−0.3), 那么为什么会是这个结果?原理之后补上. 从图里可以看到每次迭代的过程都会平行于
x
x
x轴或者
y
y
y轴, 这也是为什么这个算法叫做坐标轴下降算法.
- 验证结果是否正确
显然这个二元函数的最小值为 0 0 0,即 f ( 0 , 0 ) f(0,0) f(0,0), 那么我们只需要观察求解出的 x 、 y x、y x、y坐标是否足够接近 0 0 0即可.
def dy(x):
return 6/10*x
def dx(y):
return 6/10*y
min_x = -0.5
arrowY = [-1.0]
arrowX = [-0.5]
for i in range(1,30):
if i % 2 == 1:
min_y = dy(min_x)
print(min_y)
arrowY.append(min_y)
elif i % 2 == 0:
min_x = dx(min_y)
arrowX.append(min_x)
print(min_x)
将算法循环30次
- 结果
arrowX[-1]
arrowY[-1]
>>> -3.0704711072324055e-07
>>> -1.8422826643394434e-07
显然
x
、
y
x、y
x、y都已经足够接近
0
0
0了, 但为什么
y
y
y会略大于
x
x
x呢?, 这是因为我们执行了偶数次迭代, 所以最后一次迭代是
x
x
x轴方向的下降, 会导致
x
x
x更接近
0
0
0一些。