PyTorch实现梯度累加变相扩大batch

当显卡内存有限或无法处理大batch训练时,可以通过调整PyTorch中的梯度计算和参数更新策略。文章介绍了如何在每个小batch后平均损失,积累梯度,然后在每批数据后更新参数,以模拟大batch的效果。
摘要由CSDN通过智能技术生成

 本篇文章,主要针对显卡过小或者不容易实现大batch来求梯度更新网络参数

1)常规情况下,pytorch求梯度和进行网络参数更新

    outputs = model(inputs)
    loss = criterion(outputs,inputs)

    optimizer.zero_grad() #清空梯度
    loss.backward()       #反向传播,求梯度
    optimizer.step()      #根据优化器更新网络参数

2)显卡过小或者不容易实现大batch来求梯度更新网络参数,但又想试一下呢,可以按照以下代码进行模拟

    outputs = model(inputs)
    loss = criterion(outputs,inputs)

    loss = loss/batch_size   #相当于平均了loss
    loss.backward()          #求梯度,后面没有马上清0

    if cnt%batch_size==0:
        optimizer.step()     #根据累计到batch_size个梯度,进行网络参数更新
        optimizer.zero_grad()#梯度清0   

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值