关于Proximal Methods,近端梯度下降的理解

本文详细介绍了Proximal Methods,包括子梯度和Proximal Operator的概念,以及它们在解决不可导损失函数最小化问题中的应用。通过不动点迭代和泰勒级数展开两种方法证明了ProximalMethods的有效性,特别是对于包含L1正则化的机器学习问题。此外,还展示了ProximalMethods在逻辑回归模型中的实际应用和迭代过程。
摘要由CSDN通过智能技术生成

本文介绍了两种Proximal Methods的证明方法以及实现。内容主要来源于王然老师的《Proximal Methods》一文以及网络,加入了部分个人理解。由于水平有限,如有不妥之处,敬请指正。

为什么会有Proximal methods这个东东?

在机器学习的损失函数求解过程中,通过计算梯度然后迭代寻找最小值是一个常用的方法。而对于一些函数,是无法求导的,这时就无法用梯度下降等方法求解了。比如加了 L 1 L1 L1正则的损失函数。

a r g m i n β 1 N ∗ ∑ i ( y i − x i ∗ β t ) + λ ∗ ∥ β ∥ 1 \mathop{\mathrm{argmin}} \limits_{\beta} \frac{1}{N}*\sum_i(y_i-x_i*\beta^t)+\lambda*\Vert \beta \Vert_1 βargminN1i(yixiβt)+λβ1

proximal methods主要就是解决这个问题的。

proximal methods证明前的铺垫

主要介绍sub-differential和proximal operator这两个概念,后面证明时会用到。

sub-differential 子梯度

先介绍一个概念,sub-differential 子梯度,也叫:subderivative, subgradient, and subdifferential,是对于不可导的凸函数的导数的一种推广。
比如,对于绝对值函数 f ( x ) = ∣ x ∣ f(x)=\vert x \vert f(x)=x,当 x = 0 x=0 x=0时,函数是不可导的。
如下图,对于 x 0 x_0 x0点不可导(类似绝对值函数),但是我们可以在点 ( x 0 , f ( x 0 ) ) (x_0,f(x_0)) (x0,f(x0))上画一条线,这条线经过 x 0 x_0 x0点,并且在曲线的下方,像这样的曲线的斜率就是sub-differential中的一个。
在这里插入图片描述
子梯度的严格定义:
对于凸函数 f : I → R f:I \to \mathbb{R} f:IR x 0 x_0 x0的子梯度是一个实数 c c c c c c满足以下条件:
f ( x ) − f ( x 0 ) ≥ c ( x − x 0 ) f(x)-f(x_0)\geq c(x-x_0) f(x)f(x0)c(xx0)
对于所有在 I I I内的 x x x,在 x 0 x_0 x0的子梯度是一个非空的闭区间集合 [ a , b ] [a,b] [a,b],其中:
a = lim ⁡ x → x 0 − f ( x ) − f ( x 0 ) x − x 0 a=\lim_{x \to x_0^-} \frac{f(x)-f(x_0)}{x-x_0} a=limxx0xx0f(x)f(x0)

b = lim ⁡ x → x 0 + f ( x ) − f ( x 0 ) x − x 0 b=\lim_{x \to x_0^+} \frac{f(x)-f(x_0)}{x-x_0} b=limxx0+xx0f(x)f(x0)

sub-differential记为 ∂ f \partial f f,有:

∂ f = { y ∣ f ( x ) − f ( x 0 ) ≥ y T ( x − x 0 ) , f o r   a l l   x ∈ d o m   f } \partial f = \{ y | f(x)-f(x_0)\geq y^T(x-x_0), for \ all \ x \in dom \ f\} f={yf(x)f(x0)yT(xx0),for all xdom f}

性质:
1、当一个凸函数在 x 0 x_0 x0处的子梯度只有一个值,即 a = b a=b a=b时,函数在这个点可导。
2、如果一个凸函数在 x 0 x_0 x0处的子梯度集合为 [ a , b ] [a,b] [a,b],当 0 ∈ [ a , b ] 0 \in [a,b] 0[a,b]时,函数在 x 0 x_0 x0处取得最小值。
3、如果 f , g f,g f,g两个函数都是凸函数,则:
∂ ( f + g ) = ∂ f + ∂ g \partial(f+g)=\partial f + \partial g (f+g)=f+g

另外,维基百科上说,国内的部分机构认为的凸函数的定义与国外的正好相反,不过本文并不想纠结于这个问题。

详见:
https://en.wikipedia.org/wiki/Subderivative

Proximal Operator

还要介绍一个概念。Proximal操作算子:

p r o v f ( v ) = a r g m i n x ( f ( x ) + 1 2 ∗ ∥ x − v ∥ 2 2 ) prov_f(v)=\mathop{\mathrm{argmin}} \limits_x (f(x) + \frac{1}{2}*\Vert x-v \Vert_2^2) provf(v)=xargmin(f(x)+21xv22)

Proximal Operator有两个神奇的性质,一是不动点,二是proximal operator和sub-differential之间有一定的关系。

  • 性质一:不动点

x ∗ x^* x f ( x ) f(x) f(x)的最小值时,等价于:
x ∗ = p r o v f ( x ∗ ) x^*=prov_f(x^*) x=provf(x)
证明:
首先证明: x ∗ x^* x f ( x ) f(x) f(x)的最小值时, x ∗ = p r o v f ( x ∗ ) x^*=prov_f(x^*) x=provf(x)
f ( x ) + 1 2 ∗ ∥ x − x ∗ ∥ 2 2 ≥ f ( x ∗ ) = f ( x ∗ ) + 1 2 ∥ x ∗ − x ∗ ∥ 2 2 \begin{aligned} f(x) + \frac{1}{2}* \Vert x-x^* \Vert^2_2 & \geq f(x^*) \\ &=f(x^*)+\frac{1}{2} \Vert x^*-x^*\Vert^2_2 \\ \end{aligned} f(x)+21xx22f(x)=f(x)+21xx22
f ( x ) + 1 2 ∗ ∥ x − x ∗ ∥ 2 2 f(x) + \frac{1}{2}* \Vert x-x^* \Vert^2_2 f(x)+21xx22 x = x ∗ x=x^* x=x处取得最小值,即: x ∗ = a r g m i n x ( f ( x ) + 1 2 ∗ ∥ x − x ∗ ∥ 2 2 ) x^*=\mathop{\mathrm{argmin}} \limits_x (f(x) + \frac{1}{2}* \Vert x-x^* \Vert^2_2) x=xargmin(f(x)+21xx22),也就是 x ∗ = p r o v f ( x ∗ ) x^*=prov_f(x^*) x=provf(x)啦。

再证明:当 x ∗ = p r o v f ( x ∗ ) x^*=prov_f(x^*) x=provf(x)时, x ∗ x^* x f ( x ) f(x) f(x)的最小值。

x ∗ = p r o v f ( x ∗ ) x^*=prov_f(x^*) x=provf(x),根据sub-differential的性质,有:
0 ∈ ∂ ( p r o v f ( x ∗ ) ) 0 ∈ ∂ ( f ( x ) + 1 2 ∥ x − x ∗ ∥ 2 2 ) 0 ∈ ∂ f ( x ) + ( x − x ∗ ) 令 x = x ∗ , 则 有 : 0 ∈ ∂ f ( x ∗ ) \begin{aligned} 0 &\in \partial (prov_f(x^*)) \\ 0 &\in \partial (f(x)+\frac{1}{2}\Vert x-x^*\Vert^2_2)\\ 0 &\in \partial f(x) +(x-x^*)\\ 令x&=x^*,则有:\\ 0 &\in \partial f(x^*) \end{aligned} 000x0(provf(x))(f(x)+21xx22)f(x)+(xx)=xf(x)
x ∗ x^* x f ( x ) f(x) f(x)的最小值。

  • 性质二:proximal operator实际上是sub-differential的一种解析形式,有:
    p r o v λ f = ( I + λ ∂ f ) − 1 \begin{aligned} prov_{\lambda f}=(I+\lambda \partial f)^{-1} \end{aligned} provλf=(I+λf)1
    说明: p r o v λ f prov_{\lambda f} provλf ( I + λ ∂ f ) − 1 (I+\lambda \partial f)^{-1} (I+λf)1都是操作算子, p r o v λ f ( v ) = a r g m i n x ( f ( x ) + 1 2 λ ∥ x − v ∥ 2 2 ) prov_{\lambda f}(v)=\mathop{\mathrm{argmin}} \limits_x(f(x)+\frac{1}{2\lambda}\Vert x-v \Vert_2^2) provλf(v)=xargmin(f(x)+2λ1xv22), ( I + λ ∂ f ) − 1 (I+\lambda \partial f)^{-1} (I+λf)1 ( I + λ ∂ f ) (I+\lambda \partial f) (I+λf)的反函数。
    证明:
    如果:
    z ∈ ( I + λ ∂ f ) − 1 ( x ) ( I + λ ∂ f ) ( z ) ∋ x z + λ ∂ f ( z ) ∋ x 0 ∈ λ ∂ f ( z ) + ( z − x ) 0 ∈ ∂ ( λ f ( z ) + 1 2 ∥ z − x ∥ 2 2 ) 0 ∈ ∂ ( f ( z ) + 1 2 λ ∥ z − x ∥ 2 2 ) 即 : z = a r g m i n u ( f ( u ) + 1 2 λ ∥ u − x ∥ 2 2 ) \begin{aligned} z &\in (I+\lambda \partial f)^{-1}(x)\\ (I+\lambda \partial f)(z) &\ni x\\ z+\lambda \partial f(z) &\ni x\\ 0 &\in \lambda \partial f(z)+(z-x)\\ 0 &\in \partial(\lambda f(z)+ \frac{1}{2}\Vert z-x\Vert_2^2)\\ 0 &\in \partial( f(z)+ \frac{1}{2\lambda}\Vert z-x\Vert_2^2)\\ 即:\\ z&=\mathop{\mathrm{argmin}} \limits_u (f(u)+\frac{1}{2\lambda} \Vert u-x\Vert^2_2)\\ \end{aligned} z(I+λf)(z)z+λf(z)000z(I+λf)1(x)xxλf(z)+(zx)(λf(z)+21zx22)(f(z)+2λ1zx22)=uargmin(f(u)+2λ1ux22)
    即: ( f ( u ) + 1 2 λ ∥ u − x ∥ 2 2 ) ( f(u)+ \frac{1}{2\lambda}\Vert u-x\Vert_2^2) (f(u)+2λ1ux22) z z z处取得最小值, z = p r o v λ f ( x ) z=prov_{\lambda f}(x) z=provλf(x),注意这里的 x x x其实是前面的 v v v
    这里有点儿神奇,当 z ∈ ( I + λ ∂ f ) − 1 ( x ) z \in (I+\lambda \partial f)^{-1}(x) z(I+λf)1(x)时, z = p r o v λ f ( x ) z=prov_{\lambda f}(x) z=provλf(x)
    两个看起来没什么关系的东西竟然也能联系在一起。。。

Proximal Methods的求解证明

文章的开头,我们就提出了一个问题:对于两个函数 f + g f+g f+g,当 f f f可导,但 g g g不可导时,如何求解最小值呢?

我们先给出答案,再对其进行证明。
通过以下迭代,能够计算出 f + g f+g f+g的最小值。

x k + 1 = p r o v λ k g ( x k − λ k ∇ f ( x k ) ) x^{k+1}=prov_{\lambda^k g}(x^k-\lambda^k \nabla f(x^k)) xk+1=provλkg(xkλkf(xk))

  • 证明方法一
    如果 x ∗ x^* x f + g f+g f+g的最小值,则有 0 ∈ ∇ f ( x ∗ ) + ∂ g ( x ∗ ) 0 \in \nabla f(x^*)+ \partial g(x^*) 0f(x)+g(x)
    0 ∈ λ ∇ f ( x ∗ ) + λ ∂ g ( x ∗ ) 0 ∈ λ ∇ f ( x ∗ ) − x ∗ + x ∗ + λ ∂ g ( x ∗ ) 0 ∈ λ ∇ f ( x ∗ ) − x ∗ + ( I + λ ∂ g ) ( x ∗ ) ( I + λ ∂ g ) ( x ∗ ) ∋ x ∗ − λ ∇ f ( x ∗ ) x ∗ = ( I + λ ∂ g ) − 1 ( x ∗ − λ ∇ f ( x ∗ ) ) x ∗ = p r o v λ g ( x ∗ − λ ∇ f ( x ∗ ) ) \begin{aligned} 0& \in \lambda \nabla f(x^*)+ \lambda \partial g(x^*) \\ 0& \in \lambda \nabla f(x^*)- x^* + x^* + \lambda \partial g(x^*)\\ 0& \in \lambda \nabla f(x^*)-x^* + (I+ \lambda \partial g)(x^*)\\ (I+ \lambda \partial g)(x^*) &\ni x^*-\lambda \nabla f(x^*)\\ x^* &= (I+\lambda \partial g)^{-1}(x^*-\lambda \nabla f(x^*))\\ x^* &= prov_{\lambda g}(x^*-\lambda \nabla f(x^*)) \end{aligned} 000(I+λg)(x)xxλf(x)+λg(x)λf(x)x+x+λg(x)λf(x)x+(I+λg)(x)xλf(x)=(I+λg)1(xλf(x))=provλg(xλf(x))
    这个证明过程也是很神奇的。。。

  • 证明方法二

x k + 1 = p r o v λ k g ( x k − λ k ∇ f ( x k ) ) x k + 1 = a r g m i n x ( g ( x ) + 1 2 λ k ∥ x − ( x k − λ k ∇ f ( x k ) ) ∥ 2 2 ) x k + 1 = a r g m i n x ( g ( x ) + λ k 2 ∥ ∇ f ( x k ) ∥ 2 2 + ∇ f ( x k ) T ( x − x k ) + 1 2 λ k ∥ x − x k ∥ 2 2 ) \begin{aligned} x^{k+1}&=prov_{\lambda^k g}(x^k-\lambda^k \nabla f(x^k))\\ x^{k+1}&=\mathop{\mathrm{argmin}} \limits_{x} (g(x) +\frac{1}{2\lambda^k}\Vert x-(x^k-\lambda^k \nabla f(x^k))\Vert^2_2)\\ x^{k+1}&=\mathop{\mathrm{argmin}} \limits_{x} (g(x) +\frac{\lambda^k}{2}\Vert \nabla f(x^k) \Vert^2_2 + \nabla f(x^k)^T (x-x^k)+\frac{1}{2 \lambda^k}\Vert x-x^k\Vert^2_2)\\ \end{aligned} xk+1xk+1xk+1=provλkg(xkλkf(xk))=xargmin(g(x)+2λk1x(xkλkf(xk))22)=xargmin(g(x)+2λkf(xk)22+f(xk)T(xxk)+2λk1xxk22)
由于上式是对于 x x x求最小值,而 λ k 2 ∥ ∇ f ( x k ) ∥ 2 2 \frac{\lambda^k}{2}\Vert \nabla f(x^k) \Vert^2_2 2λkf(xk)22是一个与 x x x无关的常量,则可将其替换为 f ( x k ) f(x^k) f(xk),则上式等价于:
x k + 1 = a r g m i n x ( g ( x ) + f ( x k ) + ∇ f ( x k ) T ( x − x k ) + 1 2 λ k ∥ x − x k ∥ 2 2 ) \begin{aligned} x^{k+1}=\mathop{\mathrm{argmin}} \limits_{x} (g(x) +f(x^k) + \nabla f(x^k)^T (x-x^k)+\frac{1}{2 \lambda^k}\Vert x-x^k\Vert^2_2) \end{aligned} xk+1=xargmin(g(x)+f(xk)+f(xk)T(xxk)+2λk1xxk22)
根据泰勒级数展开:
f ( x ) = f ( x k ) + ∇ f ( x k ) T ( x − x k ) + 1 2 λ k ∥ x − x k ∥ 2 2 \begin{aligned} f(x)=f(x^k) + \nabla f(x^k)^T (x-x^k)+\frac{1}{2 \lambda^k}\Vert x-x^k\Vert^2_2 \end{aligned} f(x)=f(xk)+f(xk)T(xxk)+2λk1xxk22
则有:
x k + 1 = a r g m i n x ( g ( x ) + f ( x ) ) \begin{aligned} x^{k+1}=\mathop{\mathrm{argmin}} \limits_{x} (g(x) +f(x)) \end{aligned} xk+1=xargmin(g(x)+f(x))
说句实在话,对于上面这种方式,个人表示还能凑合着理解,第一种证明的思路实在是难以想象。

根据前文不动点的性质, x ∗ = p r o v f ( x ∗ ) x^*=prov_f(x^*) x=provf(x),类似 x k + 1 = p r o v λ k g ( x k − λ k ∇ f ( x k ) ) x^{k+1}=prov_{\lambda^k g}(x^k-\lambda^k \nabla f(x^k)) xk+1=provλkg(xkλkf(xk))这种形式迭代方式也称为不动点迭代,

对于Proximal Method的理解

这是我在网上找到的比较能够理解的说法:
对于函数 f + g f+g f+g,给定起点 x k x^{k} xk,首先可微函数 f ( x ) f(x) f(x)沿着起点的负梯度方向,作步长为 λ k \lambda^k λk的梯度下降得到一个预更新值 x k − λ k ∇ f ( x ) x^k-\lambda^k \nabla f(x) xkλkf(x),然后使用近端映射寻找一个 x x x ,这个 x x x 能使得不可微函数 g ( x ) g(x) g(x)足够小,且接近这个预更新值 x k − λ k ∇ f ( x ) x^k-\lambda^k \nabla f(x) xkλkf(x),就用这个 x x x作为本次迭代的更新值 x k + 1 x^{k+1} xk+1

还有一个问题

x k + 1 = p r o v λ k g ( x k − λ k ∇ f ( x k ) ) x^{k+1}=prov_{\lambda^k g}(x^k-\lambda^k \nabla f(x^k)) xk+1=provλkg(xkλkf(xk)),这个迭代算法为什么会成立?
除了不动点迭代外,还有一种解释这里只简单提一下,我也没深入研究(其实是水平不够,看文章太累了。。。),只是看了个皮毛。
∇ f \nabla f f是 Lipschitz continuous的,并且Lipshitz constant是 L L L的情况下,当 λ k ∈ ( 0 , 1 / L ] \lambda^k \in (0,1/L] λk(0,1/L]时,这是一个majorization-minimization method,具体可以查一下这个算法相关的资料。当 λ k > 1 / L \lambda^k > 1/L λk>1/L时,是另外一个问题。

关于不动点迭代的问题,继续解释可以了解:Forward-backward integration of gradient flow。

Proximal Methods的应用

f β ( X ) f_\beta(X) fβ(X)是负对数似然函数,其中 β \beta β是需要求解的参数, X X X是样本数据,我们希望得到下面式子的最小值:
f β ( X ) + λ ∥ β ∥ 1 , 其 中 λ > 0 \begin{aligned} f_\beta(X)+\lambda \Vert \beta \Vert_1,其中 \lambda >0 \end{aligned} fβ(X)+λβ1λ>0
怎么求解 β \beta β呢?
我们直接用 x k + 1 = p r o v λ k g ( x k − λ k ∇ f ( x k ) ) x^{k+1}=prov_{\lambda^k g}(x^k-\lambda^k \nabla f(x^k)) xk+1=provλkg(xkλkf(xk))这个迭代来搞定。
为了计算方便,我们令 ω = β k − λ k ∇ f β k ( x k ) \omega=\beta^k-\lambda^k \nabla f_{\beta^k}(x^k) ω=βkλkfβk(xk),其中 λ k \lambda^k λk中在第 k k k步迭代的步长, β k \beta^k βk是在第 k k k步迭代的 β \beta β
则有:
β k + 1 = p r o v λ g ( ω ) = a r g m i n β k ( λ k λ g ( β k ) + 1 2 ∥ β k − ω ∥ 2 2 ) = a r g m i n β k ( λ ∥ β k ∥ 1 + 1 2 λ k ∥ β k − ω ∥ 2 2 ) \begin{aligned} \beta^{k+1}&=prov_{\lambda g}(\omega)\\ =&\mathop{\mathrm{argmin} }\limits_{\beta_k}(\lambda^k \lambda g(\beta^k) + \frac{1}{2} \Vert \beta^k - \omega \Vert^2_2)\\ =&\mathop{\mathrm{argmin} }\limits_{\beta_k}(\lambda \Vert \beta^k\Vert_1 + \frac{1}{2\lambda^k} \Vert \beta^k - \omega \Vert^2_2)\\ \end{aligned} βk+1===provλg(ω)βkargmin(λkλg(βk)+21βkω22)βkargmin(λβk1+2λk1βkω22)
而:
∥ β k ∥ 1 = ∑ i ∣ β i ∣ , ∥ β k − ω ∥ 2 2 = ∑ i ( β i − ω i ) 2 \Vert \beta^k \Vert_1=\sum_i \vert \beta_i\vert, \Vert \beta^k - \omega \Vert^2_2=\sum_i (\beta_i- \omega_i)^2 βk1=iβi,βkω22=i(βiωi)2
要计算 λ ∥ β k ∥ 1 + 1 2 λ k ∥ β k − ω ∥ 2 2 ) \lambda \Vert \beta^k\Vert_1 + \frac{1}{2\lambda^k} \Vert \beta^k - \omega \Vert^2_2) λβk1+2λk1βkω22)的最小值,我们只要找到每个 λ ∣ β i ∣ + 1 2 λ k ( β i − ω i ) 2 \lambda \vert \beta_i \vert+\frac{1}{2\lambda_k}(\beta_i-\omega_i)^2 λβi+2λk1(βiωi)2的最小值,然后求和就是总体的最小值了。

对于 λ ∣ β i ∣ + 1 2 λ k ( β i − ω i ) 2 \lambda \vert \beta_i \vert+\frac{1}{2\lambda_k}(\beta_i-\omega_i)^2 λβi+2λk1(βiωi)2的最小值,因为有绝对值,需要分类讨论:

  • β i ≥ 0 \beta_i \geq0 βi0

λ ∣ β i ∣ + 1 2 λ k ( β i − ω i ) 2 = 1 2 λ k ( β i 2 + 2 ( λ k λ − ω i ) β i + ω 2 ) \begin{aligned} &\lambda \vert \beta_i \vert+\frac{1}{2\lambda_k}(\beta_i-\omega_i)^2\\ =&\frac{1}{2\lambda_k}(\beta_i^2+2(\lambda_k\lambda -\omega_i)\beta_i+\omega^2) \end{aligned} =λβi+2λk1(βiωi)22λk1(βi2+2(λkλωi)βi+ω2)
此时,当 β i = ω i − λ k λ \beta_i=\omega_i-\lambda_k\lambda βi=ωiλkλ时,取得最小值,由于 β i ≥ 0 \beta_i \geq0 βi0,要求: ω i − λ k λ ≥ 0 \omega_i-\lambda_k\lambda\geq0 ωiλkλ0
但如果: ω i − λ k λ < 0 \omega_i-\lambda_k\lambda<0 ωiλkλ<0 β i \beta_i βi无法取到 ω i − λ k λ \omega_i-\lambda_k\lambda ωiλkλ,当 β i = 0 \beta_i=0 βi=0时,取到最小值。

  • β i < 0 \beta_i<0 βi<0
    λ ∣ β i ∣ + 1 2 λ k ( β i − ω i ) 2 = 1 2 λ k ( β i 2 − 2 ( λ k λ + ω i ) β i + ω 2 ) \begin{aligned} &\lambda \vert \beta_i \vert+\frac{1}{2\lambda_k}(\beta_i-\omega_i)^2\\ =&\frac{1}{2\lambda_k}(\beta_i^2-2(\lambda_k\lambda +\omega_i)\beta_i+\omega^2) \end{aligned} =λβi+2λk1(βiωi)22λk1(βi22(λkλ+ωi)βi+ω2)
    此时,当 β i = ω i + λ k λ \beta_i=\omega_i+\lambda_k\lambda βi=ωi+λkλ时,取得最小值,由于 β i < 0 \beta_i <0 βi<0,要求: ω i + λ k λ < 0 \omega_i+\lambda_k\lambda<0 ωi+λkλ<0
    但如果: ω i + λ k λ > 0 \omega_i+\lambda_k\lambda>0 ωi+λkλ>0 β i \beta_i βi无法取到 ω i + λ k λ \omega_i+\lambda_k\lambda ωi+λkλ,当 β i = 0 \beta_i=0 βi=0时,取到最小值。
  • 综上:
    β i = { ω i − λ k λ ,           ω i > λ k λ 0 ,           − λ λ k < ω < λ λ k ω i + λ k λ ,           ω i < − λ k λ \begin{aligned} \beta_i=\begin{cases} \omega_i-\lambda_k\lambda ,      \omega_i>\lambda_k\lambda\\ 0,     -\lambda \lambda_k<\omega<\lambda\lambda_k\\ \omega_i+\lambda_k\lambda,     \omega_i<-\lambda_k\lambda \end{cases} \end{aligned} βi=ωiλkλ,     ωi>λkλ     λλk<ω<λλkωi+λkλ,     ωi<λkλ

正是由于 − λ λ k < ω < λ λ k -\lambda \lambda_k<\omega<\lambda\lambda_k λλk<ω<λλk时, β i \beta_i βi会出现截断,取值为0时才能取得最小值,才使得损失函数+ L 1 L1 L1正则化时,得到稀疏解。

Proxiaml Methods的实现

这里我就不贴自己写的代码了,直接贴一下王然老师的代码:

  1. 构造一个sigmoid函数:
def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)
  1. 构建逻辑回归模型:
def predict(beta, x):
    return sigmoid(x.dot(beta))
  1. 构造数据
key = random.PRNGKey(0)
x_key, beta_key, beta_test_key = random.split(key,3)
x = random.normal(x_key, (10000, 10))
beta = random.normal(beta_key, (10,))*2.0    #beta是一个列向量
beta_test = random.normal(beta_test_key, (10,))
y = (sigmoid(x.dot(beta))>=0.5).astype(jnp.float32)
  1. 建立逻辑回归的对数似然函数
def loss(beta):
    preds = predict(beta,x)
    #下面用了一个trick,进行了计算简化,如果不简化的话,应该是:y*jnp.log(preds) + (1 - y)jnp.log(1 - preds) ,而由于y只能为0或1,所以可以通过简化用以下的步骤实现:
    label_probs = preds * y + (1 - preds) * (1 - y) 
    return -jnp.sum(jnp.log(label_probs))/10000.00 
  1. 对损失函数求梯度,有两种方式,两个的结果是一样的:
    一是数学推导,如下:
def custom_grad(beta):
    residual = y - predict(beta, x)
    return jnp.transpose(x).dot(-residual)/10000.00

二是通过jax.grad进行计算:

grad_func = jax.grad(loss)
  1. 构造软阈值函数,就是Proximal Method最后那个 β i \beta_i βi。这里是通过jax.lax.cond来实现的,具体的介绍可以看一下官方文档,这个比较简单。
    前面写了那么那么多,在代码实现的时候,只有最后的结论能用的上。。。
def soft_threshold(x, thres):
    return jax.lax.cond(x > thres,
                        lambda _: x - thres,
                        lambda _: jax.lax.cond(
                            x < -thres,
                            lambda _: x + thres,
                            lambda _:0.0,
                            None
                        ),
                        None)
  1. Proximal methods算法的迭代过程,具体我不多介绍了,应该算是一个比较标准的迭代过程。
    特别要说明一下,其实写这些代码的关键在于如何检测每步计算都是正确的,特别是在有向量,矩阵,求导,迭代的过程中,如何验证正确性是很麻烦的,检测的过程是保证结果正确的关键。

另外这里面计算每个 β i \beta_i βi时,用的是jax.vmap实现的并行计算。
对于jax.vmap,可以参考:https://jiayiwu.me/blog/2021/04/05/learning-about-jax-axes-in-vmap.html

def proximal_methods(beta_init, max_iter, eps, lr, penalty):
    converged = False
    beta_old = beta_init
    beta_new = beta_init
    soft_threshold_partial = lambda x: soft_threshold(x, lr*penalty)
    current_iter = 0
    while not converged and current_iter < max_iter:
        print("Current iteration is %d"% current_iter)
        beta_copy = beta_old 
        current_loss = loss(beta_copy) + penalty*jnp.linalg.norm(beta_copy, 1)
        current_grad = custom_grad(beta_copy)
        w = beta_copy - lr*current_grad
        beta_new = jax.vmap(soft_threshold_partial, 0)(w)
        new_loss = loss(beta_new) + penalty*jnp.linalg.norm(beta_new, 1)
        diff = jnp.abs(new_loss-current_loss)
        print("The difference is %.5f"%diff, "   current_loss%.5f"%current_loss, "   new_loss%.5f"%new_loss,)
        beta_old = beta_new
        if diff <= eps:   
            converged = True
            print("Algorithm converged")
            break
        else:
            current_iter +=1
            if current_iter >= max_iter:
                print("The algorithm have failed to converge.")
                break

    return beta_new, converged

参考资料

  1. Proximal Mehtods,Ran Wang
  2. Proximal Algorithm,Neal Parikh,Department of Computer Science Stanford University
  3. 机器学习 | 近端梯度下降法 (proximal gradient descent)
  4. 对近端梯度算法(Proximal Gradient Method)的理解
  5. 【凸优化笔记4】-近端梯度下降(Proximal gradient descent)
  6. Majorization-Minimization优化框架
  7. 浅谈MM优化算法以及CCP算法
  • 3
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值