pytorch优化器详解:SGD

目录

说明

SGD参数

params

lr

momentum

dampening

weight_decay

nesterov

举例(nesterov为False)

第1轮迭代

第2轮迭代


说明

模型每次反向传导都会给各个可学习参数p计算出一个偏导数g_{t},用于更新对应的参数p。通常偏导数g_{t}不会直接作用到对应的可学习参数p上,而是通过优化器做一下处理,得到一个新的\widehat{g}_t,处理过程用函数F表示(不同的优化器对应的F的内容不同),即\widehat{g}_t=F(g_{t}),然后和学习率lr一起用于更新可学习参数p,即p=p-\widehat{g}_t*lr

SGD参数

SGD是随机梯度下降(stochastic gradient descent)的首字母。

torch.optim.SGD(params,
                lr=<required parameter>,
                momentum=0,
                dampening=0,
                weight_decay=0,
                nesterov=False)

params

模型里需要被更新的可学习参数。

lr

学习率。

momentum

动量值,通过上一次的v和当前的偏导数g,得到本次的v,即v_{t}=v_{t-1}*momentum+g_{t},这个就是上述的函数F。

动量是物理中的概念,它使v具有惯性,这样可以缓和v的抖动,有时候还可以帮助跳出局部盆地。比如上一次计算得到的v是10,参数更新后,本次的偏导数g是0,那么使用momentum=0.9后,最终用于更新可学习参数的v是10*0.9+0=9,而不是0,这样参数仍会得到较大的更新,就会增大离开当前局部盆地的可能性。

dampening

dampening是乘到偏导数g上的一个数,即:v_{t}=v_{t-1}*momentum+g_{t}*(1-dampening)。注意:dampening在优化器第一次更新时,不起作用。

weight_decay

weight_decay的作用是用当前可学习参数p的值修改偏导数,即:g_{t}=g_{t}+(p*weight\_decay),这里待更新的可学习参数p的偏导数就是g_{t}。然后再使用上述公式v_{t}=v_{t-1}*momentum+g_{t}*(1-dampening),计算得到v_{t}

nesterov

对应的文献还没看,从pytorch源码来看,当nesterov为False时,使用上述公式g_{t}=g_{t}+(p*weight\_decay)v_{t}=v_{t-1}*momentum+g_{t}*(1-dampening)计算得到v{t}

当nesterov为True时,在上述得到的v_{t}的基础上,最终的v_{t}=g_{t}+v_{t}*momentum,即又使用了一次momentum和g_{t}

举例(nesterov为False

def test_sgd():
    #定义一个可学习参数w,初值是100
    w = torch.tensor(data=[100], dtype=torch.float32, requires_grad=True)

    #定义SGD优化器,nesterov=False,其余参数都有效
    optimizer = torch.optim.SGD(params=[w], lr=0.1, momentum=0.9, dampening=0.5, weight_decay=0.01, nesterov=False)

    #进行5次优化
    for i in range(5):
        y = w ** 2 #优化的目标是让w的平方,即y尽可能小
        optimizer.zero_grad() #让w的偏导数置零
        y.backward() #反向传播,计算w的偏导数
        optimizer.step() #根据上述两个公式,计算一个v,然后作用到w
        print('grad=%.2f, w=%.2f' % (w.grad, w.data)) #查看w的梯度和更新后的值

'''
输入日志如下:
grad=201.00000, w=79.90000
grad=160.59900, w=53.78005
grad=108.09791, w=24.86720
grad=49.98307, w=-3.65352
grad=-7.34357, w=-28.95499
'''

第1轮更新

w^2的导数是2w,此时w=100,因此g_{1}=200。这里w就是可学习参数p

首先使用weight_decay:g_{1}=g_{1}+(p*weight\_decay)=200+(100*0.01)=201

然后使用momentum(第一次更新不使用dampening):v_{1}=v_{0}*momentum+g_{1}=g_{1}=201。这里v_{0}=0

最后更新w:w=w-lr*v_{1}=100-(0.1*201)=79.9

第2轮更新

w^2的导数是2w,此时w=79.9,因此g_{2}=159.8。这里w就是可学习参数p

首先使用weight_decay:g_{2}=g_{2}+(p*weight\_decay)=159.8+(79.9*0.01)=160.599

然后使用momentum和dampening:v_{2}=v_{1}*momentum+g_{2}*(1-dampening),即:v_{2}=201*0.9+160.599*(1-0.5)=261.1995,这里用到的v_{1}是在SGD类中缓存的。

最后更新w:w=w-lr*v_{2}=79.9-(0.1*261.1995)=53.78005

 

评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值