Pytorch 工程重构后模型参数(.pt)读取失败解决方法

Pytorch 工程重构后模型参数(.pt)读取失败解决方法

问题: 近期重构工程遇到了之前训练的模型参数没法读取的问题,即用重构后的工程在测试阶段去加载原来工程训练好的参数会方向有路径问题。

通过探究发现在.pt文件中model一项的类型包含了原来工程的路径信息,需要放置在同一个路径下才可以读取成功。重构type还是比较困难的,但是在测试代码中,我发现大多数模型在加载参数的过程中会将模型参数加载成dict的形式,而忽略没有必要的路径信息。

在这里插入图片描述
为此我们可以建立一个新字典,将state_dict的步骤放置在前面,去除后面不用用到的路径信息,完成模型参数迁移,具体代码如下:

model = torch.load("/home/input/xxx.pt", map_location="cuda:1")  # load checkpoint
new_dict = OrderedDict() 
new_dict["model"] = model["model"].state_dict()
torch.save(new_dict, "/home/output/xxx.pt")
  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值