【学习笔记】python2读取python3训练的模型(pth文件)

该文介绍了如何将Python3环境下训练的PyTorch模型转换为Python2可读的格式。主要步骤包括加载模型的.pth文件,将state_dict转换为OrderedDict并保存为新的.pth文件。
摘要由CSDN通过智能技术生成

【学习笔记】python2读取python3训练的模型(pth文件)

背景:前面跑强化学习用的Python3的环境进行训练的,现在要结合ROS部署,方便起见使用Python2的环境,发现没法在Python2下直接读取之前的.pth文件。

参考:

  1. PyTorch加载模型model.load_state_dict()问题,Unexpected key(s) in state_dict: “module.features…,Expected .
  2. Attempted to read a PyTorch file with version 3, but the maximum supported version for reading is 2.

一、Python3转Python2

  1. 在保存模型文件的文件夹内打开终端,并切换至训练模型时使用的环境(Python),打开Python:

    python
    
  2. 加载模型文件(policy.pth):

    import torch
    state_dict = torch.load('policy.pth') # 模型可以保存为pth文件,也可以为pt文件。
    
  3. 读取模型文件,并保存至有序字典中:

    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_state_dict[k] = v
    
  4. 保存:

    torch.save(new_state_dict, 'new_policy.pth', _use_new_zipfile_serialization=False)
    
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值