优化器
Optimizer
1. Optimizer基本属性
class Optimizer(object):
def __init__(self, params, defaults):
self.defaults = defaults
self.state = defaultdict(dict)
self.param_groups = []
def add_param_group(self, param_group):
for group in self.param_groups:
param_set.update(set(group['params’]))
self.param_groups.append(param_group)
def state_dict(self):
return {
'state': packed_state,
'param_groups': param_groups,
}
def load_state_dict(self, state_dict):
基本属性:
• defaults:优化器超参数
• state:参数的缓存,如momentum的缓存
• params_groups:管理的参数组(最重要)
• _step_count:记录更新次数,学习率调整中使用
• zero_grad():清空所管理参数的梯度
• step():执行一步更新
• add_param_group():添加参数组,可以设置不同参数有不同的学习率
• state_dict():获取优化器当前状态信息字典
• load_state_dict() :加载状态信息字典
提示:以下是本篇文章正文内容,下面案例可供参考
2. Optimizer基本方法
1.• step():执行一步更新
#构建可学习参数
weight = torch.randn((2, 2), requires_grad=True)
weight.grad = torch.ones((2, 2))
#构建优化器,传入可学习参数和学习
optimizer = optim.SGD([weight], lr=0.1) #构建优化器,传入可学习参数和学习
#step()
print("weight before step:{}".format(weight.data))
optimizer.step() # 修改lr=1 0.1观察结果
print("weight after step:{}".format(weight.data))
#运行结果:
weight before step:tensor([[0.6614, 0.2669],
[0.0617, 0.6213]])
weight after step:tensor([[ 0.5614, 0.1669],
[-0.0383, 0.5213]])
2.• add_param_group():添加参数组。可以设置不同参数有不同的学习率
print("optimizer.param_groups is\n{}".format(optimizer.param_groups)) #此时只有一组参数,只有一个字典
w2 = torch.randn((3, 3), requires_grad=True)
optimizer.add_param_group({"params": w2, 'lr': 0.0001}) #添加一组参数
print("optimizer.param_groups is\n{}".format(optimizer.param_groups))#现在有两个字典,两个参数有两个不同的学习率,一个0.1,一个0.0001
#运行结果:
optimizer.param_groups is
[{'params': [tensor([[0.6614, 0.2669],
[0.0617, 0.6213]], requires_grad=True)], 'lr': 0.1, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]
optimizer.param_groups is
[{'params': [tensor([[0.6614, 0.2669],
[0.0617, 0.6213]], requires_grad=True)], 'lr': 0.1, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}, {'params': [tensor([[-0.4519, -0.1661, -1.5228],
[ 0.3817, -1.0276, -0.5631],
[-0.8923, -0.0583, -0.1955]], requires_grad=True)], 'lr': 0.0001, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False}]