The optimizer is PyTorch's tool for updating model parameters. PyTorch first defines a base class, Optimizer, that implements the functionality shared by all optimizers, and then implements each optimization algorithm (SGD, Adam, and so on) as a subclass that supplies its own update rule.
# Excerpt from torch/optim/optimizer.py. At module level the file imports, among others,
# defaultdict (collections), deepcopy (copy), chain (itertools) and container_abcs
# (torch._six), all of which are used below.
class Optimizer(object):
    r"""Base class for all optimizers.

    .. warning::
        Parameters need to be specified as collections that have a deterministic
        ordering that is consistent between runs. Examples of objects that don't
        satisfy those properties are sets and iterators over values of dictionaries.

    Arguments:
        params (iterable): an iterable of :class:`torch.Tensor` s or
            :class:`dict` s. Specifies what Tensors should be optimized.
        defaults: (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).
    """

    def __init__(self, params, defaults):
        torch._C._log_api_usage_once("python.optimizer")
        # defaults holds options such as lr and momentum that apply globally to the
        # optimized variables; each subclass packs them into a dict before calling here
        self.defaults = defaults

        # params must be an iterable of Tensors or dicts, not a single Tensor
        if isinstance(params, torch.Tensor):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Tensors or dicts, but got " +
                            torch.typename(params))

        # state is a defaultdict whose default value type is dict; it stores the
        # optimizer's current state (e.g. momentum buffers)
        self.state = defaultdict(dict)
        # self.param_groups holds everything to be optimized; each entry is a dict
        # describing one group of parameters together with the options for that group
        self.param_groups = []

        param_groups = list(params)  # all variables the optimizer will update; must not be empty
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        # if bare tensors were passed, wrap them all in a single group dict
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        # register every group in self.param_groups
        for param_group in param_groups:
            self.add_param_group(param_group)
In the constructor, every group of parameters to be optimized is wrapped in a dict, collected into a list, and then added to self.param_groups. The point of this layout is that, for example when fine-tuning, each group and its options can be looked up by key; the short sketch below shows what the resulting structure looks like.
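A minimal sketch (illustrative, not library code) of self.param_groups after construction; the exact option keys come from the subclass's defaults dict:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

group = optimizer.param_groups[0]
print(list(group.keys()))    # e.g. ['params', 'lr', 'momentum', 'dampening', 'weight_decay', 'nesterov']
print(len(group['params']))  # 2: the Linear layer's weight and bias tensors

With that picture in mind, let's look at self.add_param_group():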
def add_param_group(self, param_group):
    r"""Add a param group to the :class:`Optimizer` s `param_groups`.

    This can be useful when fine tuning a pre-trained network as frozen layers
    can be made trainable and added to the :class:`Optimizer` as training progresses.

    Arguments:
        param_group (dict): Specifies what Tensors should be optimized along with
            group specific optimization options.
    """
    assert isinstance(param_group, dict), "param group must be a dict"

    params = param_group['params']
    # normalize 'params' to a list holding all tensors to be optimized in this group
    if isinstance(params, torch.Tensor):
        param_group['params'] = [params]
    elif isinstance(params, set):
        raise TypeError('optimizer parameters need to be organized in ordered collections, but '
                        'the ordering of tensors in sets will change between runs. Please use a list instead.')
    else:
        param_group['params'] = list(params)

    # every optimized variable must be a torch.Tensor and a leaf node
    # (a tensor created directly by the user rather than produced by an operation)
    for param in param_group['params']:
        if not isinstance(param, torch.Tensor):
            raise TypeError("optimizer can only optimize Tensors, "
                            "but one of the params is " + torch.typename(param))
        if not param.is_leaf:
            raise ValueError("can't optimize a non-leaf Tensor")

    # fill in the remaining options from self.defaults
    # (`required` is a sentinel defined in torch/optim/optimizer.py marking options,
    # such as SGD's lr, that the caller must supply explicitly)
    for name, default in self.defaults.items():
        if default is required and name not in param_group:
            raise ValueError("parameter group didn't specify a value of required optimization parameter " +
                             name)
        else:
            param_group.setdefault(name, default)

    # use sets to make sure the new 'params' do not already appear in self.param_groups
    # (isdisjoint returns True only if the two sets share no elements)
    param_set = set()
    for group in self.param_groups:
        param_set.update(set(group['params']))

    if not param_set.isdisjoint(set(param_group['params'])):
        raise ValueError("some parameters appear in more than one parameter group")

    # finally, register the group
    self.param_groups.append(param_group)
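A hedged sketch of the fine-tuning pattern the docstring mentions; model.backbone and model.head are hypothetical module names used purely for illustration:

# Optimize only the backbone at first.
optimizer = torch.optim.SGD(model.backbone.parameters(), lr=0.01, momentum=0.9)

# Later in training, unfreeze the head and register it as a new group with its own lr.
for p in model.head.parameters():
    p.requires_grad = True
optimizer.add_param_group({'params': model.head.parameters(), 'lr': 0.001})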
Next, the two functions for saving and loading the optimizer's state.
Getting the optimizer's state: state_dict
def state_dict(self):
    r"""Returns the state of the optimizer as a :class:`dict`.

    It contains two entries:

    * state - a dict holding current optimization state. Its content
        differs between optimizer classes.
    * param_groups - a dict containing all parameter groups
    """
    # Save ids instead of Tensors
    def pack_group(group):
        packed = {k: v for k, v in group.items() if k != 'params'}
        packed['params'] = [id(p) for p in group['params']]
        return packed

    # Every dict in self.param_groups is repacked into a new dict and collected in a
    # list, so param_groups has exactly the same layout as self.param_groups; the only
    # difference is that 'params' no longer holds the Tensors themselves but the ids
    # of the Tensor objects.
    param_groups = [pack_group(g) for g in self.param_groups]
    # Replace every Tensor key in self.state with that Tensor's id
    packed_state = {(id(k) if isinstance(k, torch.Tensor) else k): v
                    for k, v in self.state.items()}
    # return both parts as a dict
    return {
        'state': packed_state,
        'param_groups': param_groups,
    }
Note that in the returned dict the parameters are referenced by their object ids, not by the tensors or their values.
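A quick check of that note (a minimal sketch, assuming an optimizer has already been constructed, for example the SGD instance used later in this article):

sd = optimizer.state_dict()
print(sd['param_groups'][0]['params'])
print([id(p) for p in optimizer.param_groups[0]['params']])   # the same integers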
Loading previously saved state: load_state_dict
def load_state_dict(self, state_dict):
    r"""Loads the optimizer state.

    Arguments:
        state_dict (dict): optimizer state. Should be an object returned
            from a call to :meth:`state_dict`.
    """
    # deepcopy, to be consistent with module API
    state_dict = deepcopy(state_dict)
    # Validate that the loaded state matches the current optimizer's groups
    groups = self.param_groups
    saved_groups = state_dict['param_groups']

    if len(groups) != len(saved_groups):
        raise ValueError("loaded state dict has a different number of "
                         "parameter groups")
    param_lens = (len(g['params']) for g in groups)
    saved_lens = (len(g['params']) for g in saved_groups)
    # every group must contain the same number of parameters as the saved one
    if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
        raise ValueError("loaded state dict contains a parameter group "
                         "that doesn't match the size of optimizer's group")

    # Build a mapping from the saved ids to the current parameter tensors
    id_map = {old_id: p for old_id, p in
              zip(chain(*(g['params'] for g in saved_groups)),
                  chain(*(g['params'] for g in groups)))}

    # dtype / device conversion
    def cast(param, value):
        r"""Make a deep copy of value, casting all tensors to device of param."""
        if isinstance(value, torch.Tensor):
            # Floating-point types are a bit special here. They are the only ones
            # that are assumed to always match the type of params.
            if param.is_floating_point():
                value = value.to(param.dtype)
            value = value.to(param.device)
            return value
        elif isinstance(value, dict):
            return {k: cast(param, v) for k, v in value.items()}
        elif isinstance(value, container_abcs.Iterable):
            return type(value)(cast(param, v) for v in value)
        else:
            return value

    # Copy state assigned to params (and cast tensors to appropriate types).
    # State that is not assigned to params is copied as is (needed for
    # backward compatibility).
    state = defaultdict(dict)
    for k, v in state_dict['state'].items():
        if k in id_map:                # k is an id from the saved state
            param = id_map[k]          # the corresponding current parameter tensor
            state[param] = cast(param, v)
        else:
            state[k] = v

    # Rebuild the parameter groups: keep the current tensors but take the
    # options (lr, momentum, ...) from the saved groups
    def update_group(group, new_group):
        new_group['params'] = group['params']
        return new_group
    param_groups = [
        update_group(g, ng) for g, ng in zip(groups, saved_groups)]
    self.__setstate__({'state': state, 'param_groups': param_groups})
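These two functions are what the usual checkpoint save/resume pattern builds on. A minimal sketch, assuming an existing model/optimizer pair; the file name is just an example:

# save model and optimizer state together
torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict()},
           'checkpoint.pth')

# ... later, after re-creating the same model and an optimizer with the same
# parameter-group structure ...
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])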
Other functions:
Clearing gradients: zero_grad
# Clears the gradients of all optimized parameters. Because Tensor gradients in PyTorch
# accumulate by default, the old gradients must be zeroed before each backward pass so
# that every training step works with freshly computed gradients only.
def zero_grad(self):
    r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
    for group in self.param_groups:
        for p in group['params']:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()
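A small check of the accumulation behaviour mentioned above (illustrative only):

import torch

w = torch.ones(2, requires_grad=True)
w.sum().backward()
print(w.grad)        # tensor([1., 1.])
w.sum().backward()   # without zero_grad(), the new gradient is added on top
print(w.grad)        # tensor([2., 2.])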
Performing a single update: step
# implemented by each subclass
def step(self, closure):
    r"""Performs a single optimization step (parameter update).

    Arguments:
        closure (callable): A closure that reevaluates the model and
            returns the loss. Optional for most optimizers.

    .. note::
        Unless otherwise specified, this function should not modify the
        ``.grad`` field of the parameters.
    """
    raise NotImplementedError
Let's now take SGD as a concrete example of how an optimizer works.
Typical SGD usage (code from pytorch/examples/imagenet/main.py):
Defining the optimizer:
optimizer = torch.optim.SGD(model.parameters(), args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
During training:
optimizer.zero_grad()              # clear the gradients left over from the previous step
output = model(image)              # forward pass
loss = criterion(output, target)   # compute the loss
loss.backward()                    # compute the current gradients
optimizer.step()                   # update the parameters
Now let's look at the source of optim.SGD:
class SGD(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum).

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        # validate lr / momentum / weight_decay
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        # pack everything except the parameters themselves into a dict; it becomes
        # the defaults argument of the parent Optimizer
        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)
As the code above shows, SGD's constructor mainly validates and packs the hyperparameters; the real initialization work is delegated to the Optimizer base class.
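A quick illustration of those argument checks (toy values only):

import torch
import torch.nn as nn

try:
    torch.optim.SGD(nn.Linear(2, 2).parameters(), lr=-0.1)
except ValueError as e:
    print(e)   # Invalid learning rate: -0.1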
The other function SGD implements itself is step(), which updates the parameters from the current gradients. The SGD update rules are:
The conventional formulation (the one most other frameworks implement directly), writing $m$ for the momentum factor and $\lambda$ for the weight_decay coefficient:

Vanilla SGD: $w_{t+1} = w_t - \text{lr} \cdot \nabla_w J(w_t)$, with weight decay folded into the gradient as $\nabla_w J(w_t) \leftarrow \nabla_w J(w_t) + \lambda \cdot w_t$.

Momentum: $v_{t+1} = m \cdot v_t + \text{lr} \cdot \nabla_w J(w_t)$, $\quad w_{t+1} = w_t - v_{t+1}$

Nesterov momentum: $v_{t+1} = m \cdot v_t + \text{lr} \cdot \nabla_w J(w_t - m \cdot v_t)$, $\quad w_{t+1} = w_t - v_{t+1}$

These are the generic formulas. PyTorch implements the same idea but changes where the learning rate is applied: it is moved out of the velocity accumulation and applied only at the parameter update, i.e.

$v_{t+1} = m \cdot v_t + (1 - \text{dampening}) \cdot \nabla_w J(w_t)$

$w_{t+1} = w_t - \text{lr} \cdot v_{t+1}$

(weight decay is added to the gradient beforehand, and for Nesterov the step direction in the last line becomes $\nabla_w J(w_t) + m \cdot v_{t+1}$, exactly as the code below shows).
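Before going through the source, here is a small sketch (not library code; tensors and hyperparameters are arbitrary toy values) that replays $v_{t+1} = m \cdot v_t + \nabla_w J(w_t)$, $w_{t+1} = w_t - \text{lr} \cdot v_{t+1}$ by hand and checks it against torch.optim.SGD:

import torch

torch.manual_seed(0)
lr, m = 0.1, 0.9

w = torch.randn(3, requires_grad=True)   # parameter updated by torch.optim.SGD
w_manual = w.detach().clone()            # copy that we update by hand
opt = torch.optim.SGD([w], lr=lr, momentum=m)

buf = None                               # our own momentum buffer v_t
for _ in range(3):
    opt.zero_grad()
    loss = (w ** 2).sum()
    loss.backward()
    grad = w.grad.detach().clone()

    opt.step()                                       # library update

    buf = grad if buf is None else m * buf + grad    # v_{t+1} = m * v_t + grad
    w_manual -= lr * buf                             # w_{t+1} = w_t - lr * v_{t+1}

print(torch.allclose(w, w_manual))       # expected: True

With the update rule confirmed, here is the source of SGD.step():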
@torch.no_grad()
def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    # re-evaluate the loss if a closure was given
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    # update every parameter from its already-computed gradient
    for group in self.param_groups:
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        nesterov = group['nesterov']

        for p in group['params']:
            if p.grad is None:
                continue
            d_p = p.grad
            if weight_decay != 0:
                d_p = d_p.add(p, alpha=weight_decay)  # L2 regularization: grad += weight_decay * w
            if momentum != 0:
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    # first step: initialize the buffer v_t with the current gradient
                    buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                else:
                    buf = param_state['momentum_buffer']
                    # v_{t+1} = m * v_t + (1 - dampening) * grad
                    # dampening scales down the newly added gradient (the author reads it
                    # as a guard against overly large updates); it must be 0 with nesterov
                    buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    # the in-place update also refreshes self.state through buf
                if nesterov:
                    # step direction becomes grad + m * v_{t+1}: PyTorch's variant of the
                    # Nesterov formula
                    d_p = d_p.add(buf, alpha=momentum)
                else:
                    d_p = buf

            # w_{t+1} = w_t - lr * v_{t+1}
            # the in-place add_ mutates the parameter tensor itself, so the update is
            # immediately visible in the model that owns it
            p.add_(d_p, alpha=-group['lr'])
    # each iteration of the outer loop handles one 'params' group with that group's options
    return loss
As the code shows, on every iteration of the outer loop the optimizer takes the tensors stored under the 'params' key of one group and updates them with that group's own options. This makes parameter optimization more flexible; for example, you can apply weight decay to only part of the model. Example (code from MetaPruning/mobilenetv2/evaluating/evaluate.py):
# split off the weight parameters that need weight decay
all_parameters = model.parameters()
weight_parameters = []
for pname, p in model.named_parameters():
    if 'fc' in pname or 'conv1' in pname or 'pwconv' in pname:
        weight_parameters.append(p)
weight_parameters_id = list(map(id, weight_parameters))
other_parameters = list(filter(lambda p: id(p) not in weight_parameters_id, all_parameters))

# define the optimizer
optimizer = torch.optim.SGD(
    [{'params': other_parameters},
     {'params': weight_parameters, 'weight_decay': args.weight_decay}],
    args.learning_rate,
    momentum=args.momentum,
)
In this code, weight decay is applied only to the parameters of the 'fc', 'conv1' and 'pwconv' layers and not to the rest, so the two sets of parameters have to be separated. When adding them to the optimizer, note:
- Both sets must be packed into a single list; each item of the list is one group of parameters to be optimized, and each item must be a dict whose key for the parameters is 'params'.
- If a group needs its own option, such as weight_decay above, that option must sit in the same dict as the group, and its key must match the name PyTorch defines for it.
- Options passed outside the list, such as learning_rate and momentum here, apply to all optimized parameters.
Finally, a small example to make state_dict clearer:
import torch
import torch.nn as nn

class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 1, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True),
        )
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(-1, 4)
        x = self.fc(x)
        return x

model = net()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD([{'params': model.conv.parameters(), 'lr': 0.01},
                             {'params': model.fc.parameters(), 'lr': 0.02}],
                            momentum=0.9,
                            weight_decay=1e-5)

x = torch.rand(1, 1, 2, 2)
y = torch.tensor([1])

optimizer.zero_grad()
output = model(x)
loss = criterion(output, y)
loss.backward()
optimizer.step()
print(optimizer.state_dict())
The output (the keys are id() values, so they differ from run to run):
{
    'state': {
        2387458897384: {'momentum_buffer': tensor([[[[-0.0003]]]])},   # shape: [1, 1, 1, 1]
        2387458897456: {'momentum_buffer': tensor([-0.1267])},         # shape: [1]
        2387458897528: {'momentum_buffer': tensor([-0.2534])},         # shape: [1]
        2387458897816: {'momentum_buffer': tensor([[ 2.7221e-01,  5.9235e-01,  1.5873e-06,  5.2255e-07],
                                                   [-2.7221e-01, -5.9235e-01, -4.2629e-06, -4.6668e-06]])},  # shape: [2, 4]
        2387458897888: {'momentum_buffer': tensor([ 0.5391, -0.5391])} # shape: [2]
    },
    'param_groups': [
        {
            'lr': 0.01,
            'momentum': 0.9,
            'dampening': 0,
            'weight_decay': 1e-05,
            'nesterov': False,
            'params': [2387458897384, 2387458897456, 2387458897528]
        },
        {
            'lr': 0.02,
            'momentum': 0.9,
            'dampening': 0,
            'weight_decay': 1e-05,
            'nesterov': False,
            'params': [2387458897816, 2387458897888]
        }
    ]
}
From this output we can see:
- state_dict is a dict with two entries, 'state' and 'param_groups'.
- 'state' is a dict holding the most recent buffers the optimizer computed while updating the parameters (here the momentum buffers). Each key is the id of the corresponding parameter tensor; each value is itself a dict mapping the buffer name to the buffer tensor.
- 'param_groups' is a list; each item is a dict describing one group of optimized parameters together with its update options. Within each item the keys are the option names and the values their settings; the optimized parameters themselves are stored as a list of ids under the 'params' key (the small sketch below shows how to decode them).
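A hedged sketch that maps those ids back to parameter names (continuing the example above), which makes the 'state' and 'param_groups' entries much easier to read:

id_to_name = {id(p): name for name, p in model.named_parameters()}
for group in optimizer.state_dict()['param_groups']:
    print([id_to_name[pid] for pid in group['params']])
# expected: ['conv.0.weight', 'conv.1.weight', 'conv.1.bias'] then ['fc.weight', 'fc.bias']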