本篇文章,主要针对显卡过小或者不容易实现大batch来求梯度更新网络参数
1)常规情况下,pytorch求梯度和进行网络参数更新
outputs = model(inputs)
loss = criterion(outputs,inputs)
optimizer.zero_grad() #清空梯度
loss.backward() #反向传播,求梯度
optimizer.step() #根据优化器更新网络参数
2)显卡过小或者不容易实现大batch来求梯度更新网络参数,但又想试一下呢,可以按照以下代码进行模拟
outputs = model(inputs)
loss = criterion(outputs,inputs)
loss = loss/batch_size #相当于平均了loss
loss.backward() #求梯度,后面没有马上清0
if cnt%batch_size==0:
optimizer.step() #根据累计到batch_size个梯度,进行网络参数更新
optimizer.zero_grad()#梯度清0