optimizer.load_state_dict()报错parameter group不匹配的问题的原因

在加载预训练权重时可能会遇到类似下面的错误:

optimizer.load_state_dict(checkpoint['optimizer_state'])
  File "/opt/conda/lib/python3.8/site-packages/torch/optim/optimizer.py", line 145, in load_state_dict
    raise ValueError("loaded state dict contains a parameter group "
ValueError: loaded state dict contains a parameter group that doesn't match the size of optimizer's group

遇到这个问题时你去网上看,一般都是泛泛的说原因是因为模型的参数和优化器的参数不匹配,看完保证你还是一头雾水,一般遇到这样情况我干脆直接去翻看torch/optim/optimizer.py里的源码比看网上七嘴八舌甚至胡说八道的瞎说好,例如我使用的pytorch的optimizer.py的源码是这样的:

class Optimizer:
    r"""Base class for all optimizers.

    .. warning::
        Parameters need to be specified as collections that have a deterministic
        ordering that is consistent between runs. Examples of objects that don't
        satisfy those properties are sets and iterators over values of dictionaries.

    Args:
        params (iterable): an iterable of :class:`torch.Tensor` s or
            :class:`dict` s. Specifies what Tensors should be optimized.
        defaults: (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).
    """

    def __init__(self, params, defaults):
        torch._C._log_api_usage_once("python.optimizer")
        self.defaults = defaults
        self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict()
        self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict()

        self._patch_step_function()

        if isinstance(params, torch.Tensor):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Tensors or dicts, but got " +
                            torch.typename(params))

        self.state = defaultdict(dict)
        self.param_groups = []

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)

        # Allows _cuda_graph_capture_health_check to rig a poor man's TORCH_WARN_ONCE in python,
        # which I don't think exists
        # https://github.com/pytorch/pytorch/issues/72948
        self._warned_capturable_if_run_uncaptured = True
        ...

    def load_state_dict(self, state_dict):
        r"""Loads the optimizer state.

        Args:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)
        # Validate the state_dict
        groups = self.param_groups
        saved_groups = state_dict['param_groups']

        if len(groups) != len(saved_groups):
            raise ValueError("loaded state dict has a different number of "
                             "parameter groups")
        param_lens = (len(g['params']) for g in groups)
        saved_lens = (len(g['params']) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError("loaded state dict contains a parameter group "
                             "that doesn't match the size of optimizer's group")

        # Update the state
        id_map = {old_id: p for old_id, p in
                  zip(chain.from_iterable((g['params'] for g in saved_groups)),
                      chain.from_iterable((g['params'] for g in groups)))}
        
        ...

可以看到Opimizer类的param_groups是list类型,里面的每个元素是个dict,dict里面至少有params这个key,load_state_dict()里检查目前模型的optimizer的param_groups里的元素个数和预训练权重里读取到的optimizer的param_groups里的元数个数必须一致,并且两个param_groups的对应dict类型的元素里的params key对应的参数tensor(这些参数一般都是模型网络层次里的参数,也就是torch.optim.Adam(model.parameters(), lr=0.1)这样的语句创建optimizer实例时传入的model.parameters(),至于dict里保存的其他参数,例如key是lr时对应的值是学习率超参数,以及和optimizer相关的可学习参数,例如SGD的momentum、Adam的betas等参数)的长度也必须一致!

一般来说,如果你模型训练使用的Optimizer和你要加载的预训练权重保存时的Optimizer一致的话,跟Optimizer本身相关的超参数和可学习参数的个数不会有不同,如果还报上面的错误,那说明你的模型的网络结构和导出预训练权重的模型的网络结构不一致,例如保存预训练权重时的网络结构里有检测头也有分割头,而你们目前要加载权重的模型网络里只有检测头,就会触发上面的错误,要么保持网络结构的一致,要么采用类似下面的办法把预训练参数权重里目前网络结构和Optimizer需要的读取出来保存为一个新的文件,然后再调用load_state_dict()加载即可

net = new_model()
pretrained_weights = torch.load('pretrained_weights.pth')
new_model_dict = net.state_dict()
state_dict = {k:v for k,v in pretrained_weights.items() if k in new_model_dict.keys()}
new_model_dict.update(state_dict)
net.load_state_dict(new_model_dict)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Arnold-FY-Chen

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值