【PyTorch 踩坑日记】50 系显卡 + CUDA 12.8 + PyTorch 2.8 下加载模型报错 _pickle.UnpicklingError 的终极解决方案

【PyTorch 踩坑日记】50 系显卡 + CUDA 12.8 + PyTorch 2.8 下加载模型报错 _pickle.UnpicklingError 的终极解决方案(含 MMDetection 安装教程)


✅ 针对 RTX 5090 / 5080 / 5070 新显卡环境
✅ 报错 _pickle.UnpicklingErrorWeights only load failed 全面解析
✅ PyTorch 2.6 默认 weights_only=True 机制详解
✅ 适配 CUDA 12.8 的 MMDetection 安装教程
✅ 多种完整解决方案 + 可复用代码


📌 一、问题背景

随着 NVIDIA 发布 50 系显卡(如 RTX 5090 / 5080 / 5070),新硬件仅支持 CUDA 12.8 及以上版本。这对我们的深度学习环境带来重大影响:

  • 很多旧版本 PyTorch 和深度学习框架 不支持 CUDA 12.8
  • 必须升级 PyTorch 到 2.6 或更高版本
  • 升级后使用 torch.load 加载模型时,会遇到严重报错!

❗ 二、典型报错场景

假设我们在使用 MMDetection,执行如下命令加载模型:

python demo/image_demo.py demo.jpg \
    configs/faster_rcnn_r50_fpn_1x_coco.py \
    checkpoints/faster_rcnn_r50.pth

出现以下完整报错信息:

File "C:\ProgramData\miniconda3\envs\py39\lib\site-packages\torch\serialization.py", line 1548, in load 
    raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
_pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
        (1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
        (2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
        WeightsUnpickler error: Unsupported global: GLOBAL numpy.core.multiarray._reconstruct was not an allowed global by default. Please use `torch.serialization.add_safe_globals([numpy.core.multiarray._reconstruct])` or the `torch.serialization.safe_globals([numpy.core.multiarray._reconstruct])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

🎯 三、根本原因分析:PyTorch 2.6 改变默认行为

从 PyTorch 2.6 开始,torch.load 默认开启安全机制:

参数PyTorch ≤2.5 默认值PyTorch 2.6+ 默认值
weights_onlyFalseTrue

也就是说:

  • 新版本 torch.load() 默认只加载模型权重,不允许反序列化含自定义对象的 .pth 文件
  • 而老模型(如很多 .pth checkpoint)中使用了 numpy, OrderedDict, custom class
  • 所以反序列化失败,抛出 _pickle.UnpicklingError

✅ 四、解决方案一:直接设置 weights_only=False

若你信任模型来源(例如自己训练或官方 release),直接关闭安全限制:

import torch

# 加载模型 checkpoint
checkpoint = torch.load('model.pth', weights_only=False)

适用场景:

  • 自己训练的模型
  • 官方发布模型
  • 信任来源的模型(如 HuggingFace、MMDetection 官方)

🔒 五、解决方案二:启用反序列化白名单(更简单快捷)

对于不确定来源的模型,为了避免反序列化恶意代码,可以手动添加允许的反序列化类型。

✅ 这些代码可以直接放到运行脚本的顶部导入区,添加一次即可生效。比如在最终运行的 image_demo.py 文件的顶部

方式一:添加到全局

import torch
import numpy as np

torch.serialization.add_safe_globals([np.core.multiarray._reconstruct])
checkpoint = torch.load('model.pth')

方式二:使用上下文管理器(推荐)

import torch
import numpy as np

with torch.serialization.safe_globals([np.core.multiarray._reconstruct]):
    checkpoint = torch.load('model.pth')

🧱 六、MMDetection 安装流程(适配 50 系显卡 + CUDA 12.8)

✅ 1. 创建 conda 虚拟环境

conda create -n openmmlab python=3.9 -y
conda activate openmmlab

✅ 2. 安装 PyTorch(Nightly 版本支持 CUDA 12.8)

pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128

确认安装成功:

python -c "import torch; print(torch.__version__); print(torch.cuda.is_available())"

✅ 3. 安装 MMCV(通过 OpenMIM 自动适配版本)

pip install -U openmim
mim install mmcv==2.1.0  #这是我唯一试出来可行的版本

✅ 4. 安装 MMDetection

git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection
pip install -r requirements/build.txt
pip install -v -e .

✅ 5. 运行 demo 检测测试

python demo/image_demo.py demo/demo.jpg \
    configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py \
    checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth

🛠 七、推荐封装函数:兼容旧模型加载

def safe_load(path):
    import torch
    import numpy as np
    torch.serialization.add_safe_globals([np.core.multiarray._reconstruct])
    return torch.load(path)

使用方式:

checkpoint = safe_load('model.pth')

⚙ 八、推荐环境配置(50 系显卡最佳搭配)

组件推荐版本
显卡RTX 5090 / 5080 / 5070
CUDA✅ CUDA 12.8 及以上
Python✅ Python 3.9
PyTorch✅ PyTorch 2.8
MMCV✅ 2.1.0
MMDetection✅ 最新稳定版本

💡 九、经验总结

  • PyTorch 2.6 引入 weights_only=True 机制是出于安全考虑
  • 老模型容易踩坑,建议明确传参或手动配置白名单
  • CUDA 12.8 是未来趋势,MMDetection、MMCV 都要使用新版 PyTorch
  • 强烈建议使用 conda 管理环境、使用 OpenMIM 安装 mmcv

🚀 十、终极解决方案:全局 monkey patch torch.load(强烈推荐!)

如果你项目中使用了大量第三方库(如 MMDetection / HuggingFace / Detectron2),你不可能每次都手动加 weights_only=False,这时候最推荐的方式就是 —— 全局重写 PyTorch 的 torch.load 函数行为

✅ 原理:Monkey Patch

通过“猴子补丁”,我们可以在程序启动时一次性覆盖 PyTorch 的 torch.load 行为,让它永远默认 weights_only=False

✅ 补丁代码如下:

import torch

# 保存原始 torch.load 函数
_original_torch_load = torch.load

# 覆盖 torch.load,默认强制 weights_only=False
def patched_torch_load(*args, **kwargs):
    kwargs['weights_only'] = False
    return _original_torch_load(*args, **kwargs)

# 应用 monkey patch
torch.load = patched_torch_load

🔧 如何使用?

只需要把上面这段代码,放到你项目的入口脚本顶部,例如:

  • train.py
  • inference.py
  • tools/train.py(OpenMMLab 系项目)
  • demo/image_demo.py(MMDetection)

即可全局生效,兼容所有后续 torch.load() 调用,无需修改其他代码!


📦 Bonus:补丁 + 白名单合体方案(防止 numpy 报错)

import torch
import numpy as np

# 添加 numpy 支持
torch.serialization.add_safe_globals([np.core.multiarray._reconstruct])

# 重写 torch.load,自动关闭 weights_only
_original_torch_load = torch.load

def patched_torch_load(*args, **kwargs):
    kwargs['weights_only'] = False
    return _original_torch_load(*args, **kwargs)

torch.load = patched_torch_load

🔗 参考链接


如果这篇文章对你有帮助:

📌 点赞 👍 + 收藏 ⭐ + 留言 💬 支持我更新更多 PyTorch 实战内容!

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值