Python坐标轴下降算法

该算法是为了解决L1正则化下线性回归无法使用梯度下降法求解的问题
先直观的来了解一下坐标轴下降算法
给定二元函数
f ( x , y ) = 5 x 2 − 6 x y + 5 y 2 f(x,y)=5x^2-6xy+5y^2 f(x,y)=5x26xy+5y2
如何求解该函数的最小值?(虽然很容易就能看出是0), 坐标轴下降算法可以解决这个问题

  • 先来看这个函数的图像
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np


x = np.linspace(-1.5,1.5,100)
y = np.linspace(-1.5,1.5,100)

X, Y = np.meshgrid(x, y)
f = 5 * pow(X, 2) - 6 * X * Y + 5 * pow(Y,2)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, f, rstride=1, cstride=1, cmap='rainbow')
plt.show()

在这里插入图片描述
很明显这是个凸函数, 函数的最小值应该位于图像最中心的低谷处.

  • 再来看看等高线图
x = np.linspace(-1.5, 1.5, 100)
y = np.linspace(-1.5, 1.5, 100)

X, Y = np.meshgrid(x, y)
f = 5 * pow(X, 2) - 6 * X * Y + 5 * pow(Y, 2)
plt.figure(figsize=(10, 12))
figure = plt.contour(X, Y, f, [0.5, 1.4, 2.3, 3.2, 4.1, 5], colors='k')
plt.clabel(figure, fontsize=10, colors="k")

plt.vlines(0, min(y), max(y))
plt.hlines(0, min(x), max(x))
plt.xlabel("x")
plt.ylabel("y")
plt.show()

在这里插入图片描述
这样就很明显能看到函数的最小值在哪里了.

  • 坐标轴下降算法
def dy(x):
    return 6/10*x

def dx(y):
    return 6/10*y

min_x = -0.5 #初值
arrowY = [-1.0] #初值
arrowX = [-0.5] #初值
for i in range(1,10):
    if i % 2 == 1:
        min_y = dy(min_x)
        print(min_y)
        arrowY.append(min_y)
    elif i % 2 == 0:
        min_x = dx(min_y)
        arrowX.append(min_x)
        print(min_x)
      
x = np.linspace(-1.5, 1.5, 100)
y = np.linspace(-1.5, 1.5, 100)

X, Y = np.meshgrid(x, y)
f = 5 * pow(X, 2) - 6 * X * Y + 5 * pow(Y, 2)
plt.figure(figsize=(10, 12))
figure = plt.contour(X, Y, f, [0.5, 1.4, 2.3, 3.2, 4.1, 5], colors='k')
plt.clabel(figure, fontsize=10, colors="k")
plt.annotate(
    'min',
    xy=(0, 0),
    xytext=(0.15, 0.15),
    arrowprops=dict(facecolor='red', shrink=0.01),
)
for i in range(4):
    plt.arrow(arrowX[i], 
              arrowY[i], 
              0,   
              arrowY[i+1]-arrowY[i],  
              width=0.0005,
              length_includes_head=True,
              ec='r',
              head_width=0.02,
              fc='r')
for i in range(4):
    plt.arrow(arrowX[i], 
              arrowY[i+1], 
              arrowX[i+1]-arrowX[i],   
              0,  
              width=0.0005,
              length_includes_head=True,
              ec='r',
              head_width=0.02,
              fc='r')
plt.vlines(0, min(y), max(y))
plt.hlines(0, min(x), max(x))
plt.xlabel("x")
plt.ylabel("y")
plt.show()

在这里插入图片描述
假如我们给定初值 ( − 0.5 , − 1.0 ) (-0.5,-1.0) (0.5,1.0)然后要求只按照坐标轴(要同时包括 x 、 y x、y xy轴)的方向来下降(迭代)这个数值. 下一次迭代的结果将是 ( − 0.18 , − 0.3 ) (-0.18,-0.3) (0.18,0.3), 那么为什么会是这个结果?原理之后补上. 从图里可以看到每次迭代的过程都会平行于 x x x轴或者 y y y轴, 这也是为什么这个算法叫做坐标轴下降算法.

  • 验证结果是否正确
    显然这个二元函数的最小值为 0 0 0,即 f ( 0 , 0 ) f(0,0) f(0,0), 那么我们只需要观察求解出的 x 、 y x、y xy坐标是否足够接近 0 0 0即可.
def dy(x):
    return 6/10*x

def dx(y):
    return 6/10*y

min_x = -0.5
arrowY = [-1.0]
arrowX = [-0.5]
for i in range(1,30):
    if i % 2 == 1:
        min_y = dy(min_x)
        print(min_y)
        arrowY.append(min_y)
    elif i % 2 == 0:
        min_x = dx(min_y)
        arrowX.append(min_x)
        print(min_x)

将算法循环30次

  • 结果
arrowX[-1]
arrowY[-1]

>>> -3.0704711072324055e-07
>>> -1.8422826643394434e-07

显然 x 、 y x、y xy都已经足够接近 0 0 0了, 但为什么
y y y会略大于 x x x呢?, 这是因为我们执行了偶数次迭代, 所以最后一次迭代是 x x x轴方向的下降, 会导致 x x x更接近 0 0 0一些。

  • 5
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 7
    评论
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Infinity343

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值