In PyTorch, when we need to reset an optimizer, the approach that probably comes to mind first is simply re-creating it:
opt = optim.Adam(m.parameters(), lr=1e-3)
Let's look at an example.
import torch
import torch.nn as nn
from torch import optim
import collections
m = nn.Linear(3, 1)
opt = optim.Adam(m.parameters(), lr=1e-3)
out = m(torch.rand(3))
out.backward()
opt.step()
print(opt.state)
Output:
defaultdict(<class 'dict'>, {Parameter containing:
tensor([[ 0.2271, -0.3494, 0.5265]], requires_grad=True): {'step': 1, 'exp_avg': tensor([[0.0398, 0.0358, 0.0628]]), 'exp_avg_sq': tensor([[0.0002, 0.0001, 0.0004]])}, Parameter containing:
tensor([0.1121], requires_grad=True): {'step': 1, 'exp_avg': tensor([0.1000]), 'exp_avg_sq': tensor([0.0010])}})
Now reset it by re-creating the optimizer:
opt = optim.Adam(m.parameters(), lr=1e-3)
print(opt.state)
As a result, optimizer.state has been emptied:
defaultdict(<class 'dict'>, {})
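Re-creating the optimizer does clear the state, but it also produces a brand-new object, so anything that still holds a reference to the old optimizer (a learning-rate scheduler, for instance) keeps pointing at the stale one. A minimal sketch of that pitfall:

```python
import torch
import torch.nn as nn
from torch import optim

m = nn.Linear(3, 1)
opt = optim.Adam(m.parameters(), lr=1e-3)
sched = optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.5)

# Re-creating the optimizer leaves the scheduler bound to the old object
opt = optim.Adam(m.parameters(), lr=1e-3)
print(sched.optimizer is opt)  # False -- the scheduler still drives the stale optimizer
```

This is one reason resetting the state in place, as shown next, is often preferable.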
A more common and elegant approach is to reset the state in place:
opt.state = collections.defaultdict(dict) # Reset state
print(opt.state)
The output is the same:
defaultdict(<class 'dict'>, {})
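To confirm the reset really takes effect, we can step once more and check that Adam's per-parameter step counters start over from 1 rather than continuing at 2. A self-contained sketch (the names mirror the example above):

```python
import collections
import torch
import torch.nn as nn
from torch import optim

m = nn.Linear(3, 1)
opt = optim.Adam(m.parameters(), lr=1e-3)
m(torch.rand(3)).backward()
opt.step()                                  # state now records step == 1

opt.state = collections.defaultdict(dict)   # reset state in place
m(torch.rand(3)).backward()
opt.step()                                  # state is rebuilt from scratch

# every parameter's step counter restarted at 1, not 2
print([int(s["step"]) for s in opt.state.values()])
```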
collections.defaultdict has many clever uses that are beyond the scope of this post; the references below are recommended instead.
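For completeness, here is a tiny illustration of the behavior the reset relies on: a defaultdict(dict) conjures up an empty dict for any missing key instead of raising KeyError, which is exactly the shape opt.state starts out with:

```python
import collections

state = collections.defaultdict(dict)
state["w"]["step"] = 1     # no need to initialize state["w"] first
print(state["w"])          # {'step': 1}
print(state["missing"])    # {} -- auto-created on first access
```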
Using collections.defaultdict() in Python: https://www.jianshu.com/p/26df28b3bfc8
This post also draws on: https://discuss.pytorch.org/t/reset-adaptive-optimizer-state/14654