pytorch load state dict_源码详解Pytorch的state_dict和load_state_dict

本文详细介绍了PyTorch中模型的state_dict和load_state_dict的原理与实现。state_dict存储了网络结构的名字和参数,通过递归方式读取。load_state_dict函数用于恢复模型参数,它会对比保存的模型参数和当前模型结构,确保参数正确加载,遇到不匹配项可以选择忽略或报错。
摘要由CSDN通过智能技术生成

在 Pytorch 中一种模型保存和加载的方式如下:

# save
torch.save(model.state_dict(), PATH)

# load
model = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

model.state_dict()其实返回的是一个OrderDict,存储了网络结构的名字和对应的参数,下面看看源代码如何实现的。

state_dict

# torch.nn.modules.module.py
class Module(object):
    def state_dict(self, destination=None, prefix='', keep_vars=False):
        if destination is None:
            destination = OrderedDict()
            destination._metadata = OrderedDict()
        destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
        for name, param in self._parameters.items():
            if param is not None:
                destination[prefix + name] = param if keep_vars else param.data
        for name, buf in self._buffers.items():
            if buf is not None:
                destination[prefix + name] = buf if keep_vars else buf.data
        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
        for hook in self._state_dict_hooks.values():
            hook_result = hook(self, destination, prefix, local_metadata)
            if hook_result is not None:
                destination = hook_result
        return destination

可以看到state_dict函数中遍历了4中元素,分别是_paramters,_buffers,

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值