Python Implementations of Gradient Descent Algorithms

Preface

Gradient descent (GD) iteratively steps along the direction of steepest descent to approach a function's minimum. This article implements the following gradient descent variants in Python:

  • Simple gradient descent
  • Batch gradient descent (BGD)
  • Stochastic gradient descent (SGD)

Simple Gradient Descent

The core of simple gradient descent is to compute the derivative $g_k$ of the objective function, then iterate toward the minimum with the update formula
$$x_{k+1} = x_k - g_k \cdot r_k$$

  • $x_{k+1}$: the next position
  • $x_k$: the current position
  • $g_k$: the gradient
  • $r_k$: the learning rate (step size)

Take the objective function $f(x)=x^2$. For a one-variable function, the derivative is the slope of the tangent line at a point on the curve: the larger the derivative, the steeper the slope and the faster the descent. Here the gradient is $f'(x)=2x$, and $r_k$ controls how far each gradient step moves; it must be neither too large nor too small.
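The sensitivity to $r_k$ is easy to check numerically. Below is a minimal sketch applying the update $x \leftarrow x - 2x \cdot r$ for $f(x)=x^2$; the step sizes are illustrative values, not tuned recommendations:

```python
# Effect of the learning rate on f(x) = x**2, whose gradient is 2x.
# The step sizes below are illustrative, not tuned recommendations.
def iterate(x, lr, steps=20):
    for _ in range(steps):
        x = x - 2 * x * lr  # gradient step: x_{k+1} = x_k - g_k * r_k
    return x

print(iterate(1.0, 0.01))  # too small: creeps toward 0 slowly
print(iterate(1.0, 0.4))   # reasonable: reaches ~0 quickly
print(iterate(1.0, 1.1))   # too large: every step overshoots, so |x| grows
```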


Code

import numpy as np
import matplotlib.pyplot as plt


# Objective function f(x) = x**2 + 1
def f(x):
    return np.array(x)**2 + 1

# Derivative of the objective function: d(x) = 2x
def d1(x):
    return x * 2

def Gradient_Descent_d1(current_x = 0.1, learn_rate = 0.01, e = 0.001, count = 50000):
    # current_x: initial x value
    # learn_rate: learning rate
    # e: error tolerance
    # count: maximum number of iterations
    for i in range(count):
        grad = d1(current_x)  # current gradient
        if abs(grad) < e:     # gradient has converged within tolerance
            break             # exit the loop
        current_x = current_x - grad * learn_rate  # one-dimensional update formula
        print("Iteration {}: approximation {}".format(i+1, current_x))

    print("Minimum:", current_x)
    print("Minimum to 6 decimal places: %.6f" % (current_x))
    return current_x

# Plot the objective function curve and the approximate minimum found by gradient descent
def ShowLine(min_X,max_Y):
    x = [x for x in range(10)] + [x * (-1) for x in range(1,10)]
    x.sort()
    print(x)
    plt.plot(x,f(x))
    plt.plot(min_X,max_Y,'ro')
    plt.show()

minValue = Gradient_Descent_d1(current_x = 0.1,learn_rate = 0.01,e = 0.001,count = 50000)
minY = f(minValue)
print('Approximate minimum of the objective function:', minY)
ShowLine(minValue,minY)

Output (truncated)

Iteration 1: approximation 0.098
Iteration 2: approximation 0.09604
Iteration 3: approximation 0.0941192
Iteration 4: approximation 0.092236816
...
Iteration 262: approximation 0.0005026108174106739
Iteration 263: approximation 0.0004925586010624604
Minimum: 0.0004925586010624604
Minimum to 6 decimal places: 0.000493
Approximate minimum of the objective function: 1.0000002426139756
[-9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]




Batch Gradient Descent

Batch gradient descent (BGD) computes the gradient over all of the sample data at each iteration. "All" matters here for non-convex functions, i.e. those with multiple local maxima or minima.
The loss function is
$$J(\theta)=\frac{1}{2n}\sum_{i=1}^{n}\left(h_\theta(x^i)-y^i\right)^2$$

  • $n$ is the number of samples
  • the factor $\frac{1}{2}$ cancels when taking the partial derivative
  • $x^i, y^i$ are the $(x, y)$ coordinates of the $i$-th sample

The hypothesis function is
$$h_\theta(x^i)=\theta_0+\theta_1 x_1^i+\theta_2 x_2^i+\cdots+\theta_n x_n^i$$
The batch gradient is obtained by differentiating the loss function over the full sample set (of any dimensionality):
$$\nabla_\theta J(\theta)=\frac{\partial J(\theta)}{\partial \theta_j}=\frac{1}{n}\sum_{i=1}^{n}\left(h_\theta(x^i)-y^i\right)x_j^i$$
where $i=1,2,\cdots,n$ indexes the samples and $j=0,1$ indexes the features.
The batch update formula is
$$\theta=\theta-\mu\cdot\nabla_\theta J(\theta)$$
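The formulas above can be sketched as a small linear-regression example. The data, learning rate, and epoch count below are illustrative assumptions, not values from this article:

```python
import numpy as np

# Batch gradient descent for linear regression: a minimal sketch.
def batch_gradient_descent(X, y, lr=0.01, epochs=1000):
    n, m = X.shape
    Xb = np.hstack([np.ones((n, 1)), X])  # prepend a column of 1s for theta_0
    theta = np.zeros(m + 1)
    for _ in range(epochs):
        error = Xb @ theta - y            # h_theta(x^i) - y^i for every sample
        grad = Xb.T @ error / n           # (1/n) * sum(error * x_j^i), per feature j
        theta -= lr * grad                # theta = theta - mu * grad
    return theta

# Fit y = 1 + 2x on noiseless data
X = np.arange(10, dtype=float).reshape(-1, 1)
y = 1 + 2 * X.ravel()
theta = batch_gradient_descent(X, y, lr=0.02, epochs=5000)
print(theta)  # close to [1., 2.]
```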

Stochastic Gradient Descent

Stochastic gradient descent (SGD) uses a single randomly chosen sample per gradient iteration, eventually approaching the extremum to yield an approximate result. The loss function is
$$J(\theta)=\frac{1}{2}\left(h_\theta(x^i)-y^i\right)^2$$
Note that SGD computes each iteration from just one sample, so unlike batch gradient descent, a single loss iteration does not touch all samples and no averaging is needed:
$$\nabla_\theta J(\theta)=\frac{\partial J(\theta)}{\partial \theta_j}=\left(h_\theta(x^i)-y^i\right)x_j^i$$
The update formula is
$$\theta=\theta-\mu\cdot\nabla_\theta J(\theta)$$
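As a sketch, the batch loop can be swapped for per-sample updates; again the data and hyperparameters are illustrative assumptions, not values from this article:

```python
import numpy as np

# Stochastic gradient descent: one randomly chosen sample per update.
def stochastic_gradient_descent(X, y, lr=0.01, epochs=500, seed=0):
    rng = np.random.default_rng(seed)
    n, m = X.shape
    Xb = np.hstack([np.ones((n, 1)), X])  # prepend a column of 1s for theta_0
    theta = np.zeros(m + 1)
    for _ in range(epochs):
        for i in rng.permutation(n):       # shuffle sample order each epoch
            error = Xb[i] @ theta - y[i]   # h_theta(x^i) - y^i for one sample
            theta -= lr * error * Xb[i]    # per-sample update, no averaging
    return theta

# Fit y = 1 + 2x on noiseless data
X = np.arange(10, dtype=float).reshape(-1, 1)
y = 1 + 2 * X.ravel()
theta = stochastic_gradient_descent(X, y)
print(theta)  # close to [1., 2.]
```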
