When loading pretrained weights you may run into an error like the following:
    optimizer.load_state_dict(checkpoint['optimizer_state'])
  File "/opt/conda/lib/python3.8/site-packages/torch/optim/optimizer.py", line 145, in load_state_dict
    raise ValueError("loaded state dict contains a parameter group "
ValueError: loaded state dict contains a parameter group that doesn't match the size of optimizer's group
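For reference, here is a minimal sketch that reproduces the error (the layer sizes and the checkpoint layout are made up): the saved optimizer state was built over more parameters than the current optimizer's group contains.

import torch
import torch.nn as nn

# Hypothetical setup: the checkpoint's optimizer saw two layers, the current one only sees one.
big_model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))   # e.g. detection + segmentation heads
small_model = nn.Sequential(nn.Linear(4, 8))                  # only the detection head

saved_optimizer = torch.optim.Adam(big_model.parameters(), lr=0.1)
checkpoint = {'optimizer_state': saved_optimizer.state_dict()}  # what a training script might save

optimizer = torch.optim.Adam(small_model.parameters(), lr=0.1)
optimizer.load_state_dict(checkpoint['optimizer_state'])        # raises the ValueError above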
If you search online for this error, the answers are usually just a vague "the model's parameters don't match the optimizer's parameters", which leaves you as confused as before. In situations like this I find it better to go straight to the source of torch/optim/optimizer.py than to wade through the noisy and often plainly wrong answers online. For example, the optimizer.py in the PyTorch version I use looks like this:
class Optimizer:
    r"""Base class for all optimizers.

    .. warning::
        Parameters need to be specified as collections that have a deterministic
        ordering that is consistent between runs. Examples of objects that don't
        satisfy those properties are sets and iterators over values of dictionaries.

    Args:
        params (iterable): an iterable of :class:`torch.Tensor` s or
            :class:`dict` s. Specifies what Tensors should be optimized.
        defaults: (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).
    """

    def __init__(self, params, defaults):
        torch._C._log_api_usage_once("python.optimizer")
        self.defaults = defaults
        self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict()
        self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict()

        self._patch_step_function()

        if isinstance(params, torch.Tensor):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Tensors or dicts, but got " +
                            torch.typename(params))

        self.state = defaultdict(dict)
        self.param_groups = []

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)

        # Allows _cuda_graph_capture_health_check to rig a poor man's TORCH_WARN_ONCE in python,
        # which I don't think exists
        # https://github.com/pytorch/pytorch/issues/72948
        self._warned_capturable_if_run_uncaptured = True

    ...

    def load_state_dict(self, state_dict):
        r"""Loads the optimizer state.

        Args:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)
        # Validate the state_dict
        groups = self.param_groups
        saved_groups = state_dict['param_groups']

        if len(groups) != len(saved_groups):
            raise ValueError("loaded state dict has a different number of "
                             "parameter groups")
        param_lens = (len(g['params']) for g in groups)
        saved_lens = (len(g['params']) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError("loaded state dict contains a parameter group "
                             "that doesn't match the size of optimizer's group")

        # Update the state
        id_map = {old_id: p for old_id, p in
                  zip(chain.from_iterable((g['params'] for g in saved_groups)),
                      chain.from_iterable((g['params'] for g in groups)))}
        ...
As you can see, the Optimizer class's param_groups is a list whose elements are dicts, and each dict has at least a 'params' key. load_state_dict() checks that the current optimizer's param_groups contains the same number of groups as the param_groups read from the pretrained checkpoint, and that each corresponding pair of group dicts holds the same number of parameter tensors under the 'params' key. These tensors are normally the model's layer parameters, i.e. the model.parameters() passed in when the optimizer was created with a statement like torch.optim.Adam(model.parameters(), lr=0.1); the remaining keys in each group dict are hyperparameters, e.g. lr holds the learning rate, plus optimizer-specific settings such as SGD's momentum or Adam's betas.
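To see this structure for yourself, you can print the param_groups of a freshly built optimizer; the two-layer model below is just an illustration.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

# param_groups is a list with one dict per group; 'params' holds the tensors being
# optimized, and the other keys are hyperparameters such as 'lr', 'betas', 'eps', ...
group = optimizer.param_groups[0]
print([k for k in group if k != 'params'])   # e.g. ['lr', 'betas', 'eps', 'weight_decay', 'amsgrad', ...]
print(len(group['params']))                  # 4: two weights and two biases

# state_dict() stores the same groups (with tensors replaced by ids), and
# len(g['params']) per group is exactly what load_state_dict() compares.
print([len(g['params']) for g in optimizer.state_dict()['param_groups']])   # [4]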
Generally speaking, if the Optimizer you train with is the same kind as the one used when the pretrained weights were saved, the hyperparameters and internal state tied to the Optimizer itself will not differ in number. So if you still hit the error above, it means your model's network structure does not match the structure of the model that exported the pretrained weights. For example, the network that saved the weights had both a detection head and a segmentation head, while the network you are now loading into only has a detection head; that mismatch is enough to trigger the error. Either keep the two network structures identical, or use something like the code below to extract from the pretrained weights just the parameters the current network and optimizer need, optionally save them as a new file, and then call load_state_dict() to load them:
import torch

net = new_model()                                          # build the current (smaller) network
pretrained_weights = torch.load('pretrained_weights.pth')  # assumes the .pth stores the model's state_dict directly

# Keep only the pretrained entries whose names also exist in the current model,
# then merge them into the current state_dict and load the result.
new_model_dict = net.state_dict()
state_dict = {k: v for k, v in pretrained_weights.items() if k in new_model_dict}
new_model_dict.update(state_dict)
net.load_state_dict(new_model_dict)
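If you want the filtered weights as a standalone file, as described above, you can save new_model_dict and load it in later runs; and since the old optimizer_state was built for a different parameter set, the simplest option is usually to create a fresh optimizer for the new model rather than loading it. A short sketch (the file name and learning rate are just examples):

# Optionally save the filtered weights as a new checkpoint file for later runs.
torch.save(new_model_dict, 'filtered_weights.pth')

# The saved optimizer_state no longer matches the current parameter groups, so
# rebuild the optimizer from the new model's parameters instead of loading it.
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)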