model.load_state_dict(state_dict, strict=False)

我们在构造好了一个模型后,可能要加载一些训练好的模型参数。举例子如下:

假设  trained.pth 是一个训练好的网络的模型参数存储

model = Net()是我们刚刚生成的一个新模型,我们希望model将trained.pth中的参数加载加载进来,但是model中多了一些trained.pth中不存在的参数,如果使用下面的命令:

state_dict = torch.load('trained.pth')
model.load_state_dict(state_dict)

会报错,说key对应不上,因为model你强人所难,我堂堂trained.pth没有你的那些个零碎玩意,你非要向我索取,我上哪给你弄去。

但是model不干,说既然你不能完全满足我的需要,那么你有什么我就拿什么吧,怎么办呢?下面的指令代码就行了。

model.load_state_dict(state_dict, strict=False)
  • 247
    点赞
  • 310
    收藏
    觉得还不错? 一键收藏
  • 34
    评论
`model.load_state_dict()` 函数是 PyTorch 中用于加载模型参数的函数。它的作用是将预训练或保存的模型参数应用到指定的模型对象上。 `load_state_dict()` 函数的基本语法如下: ```python model.load_state_dict(state_dict, strict=True) ``` 其中,`state_dict` 是一个包含模型参数的字典对象,它通常是通过 `torch.load()` 函数加载预训练或保存的模型文件得到的。`strict` 是一个布尔值参数,用于指定是否严格加载参数。 使用 `load_state_dict()` 函数可以完成以下任务: 1. 加载预训练模型参数:可以将预训练模型的权重加载到指定的模型对象中。通常,需要先创建一个与预训练模型结构相同的空模型对象,然后使用 `load_state_dict()` 函数将预训练模型的参数应用到该模型对象上。 2. 加载保存的模型参数:可以将保存的模型参数加载到指定的模型对象中。在使用 `torch.save()` 函数保存模型时,通常使用 `model.state_dict()` 方法获取模型的参数字典,然后将其保存到文件中。加载时,可以使用 `torch.load()` 函数加载保存的模型文件,并使用 `load_state_dict()` 函数将加载的参数应用到模型对象上。 示例代码: ```python # 创建空模型对象 model = MyModel() # 加载预训练模型参数 pretrained_state_dict = torch.load('pretrained_model.pt') model.load_state_dict(pretrained_state_dict) # 或者加载保存的模型参数 saved_state_dict = torch.load('saved_model.pt') model.load_state_dict(saved_state_dict) ``` 通过以上代码,可以加载预训练模型的参数或保存的模型的参数,并将其应用到 `MyModel` 类型的 `model` 对象上。这样,`model` 对象就具有了与预训练模型或保存的模型相匹配的权重。
评论 34
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值