torch.load raises ModuleNotFoundError or AttributeError

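This article explains why loading an object of a custom type that was saved with torch.save can make torch.load raise an AttributeError: pickle, which torch.save uses underneath, does not save type definitions. The fixes are to make sure the custom type is accessible at load time or, when it cannot be accessed, to inspect the saved data through a __setstate__ method. When the custom type comes from another module, a mock module can be used to get past the missing module.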
Environment: Python 3.11.3 (main, Apr  7 2023, 19:25:52) [Clang 14.0.0 (clang-1400.0.29.202)] on darwin

Normally we use torch.save to save a model's state_dict, but torch.save can also save an object of a custom type, for example:

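# save.py
import torch
import torch.nn as nn
class Module(nn.Module):
    def __init__(self) -> None:
        # nn.Module.__init__ is not called here, so the instance __dict__,
        # which is what pickle saves, holds only _one
        self._one = 1
# guard so that `from save import Module` does not save the file again
if __name__ == '__main__':
    torch.save(Module(), 'module.pth')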
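To see what actually gets written for such an object, one can ask it for the reduce tuple that pickle serializes. The sketch below uses a plain Module class without torch (an assumption to keep it self-contained) and protocol 2, which is also torch.save's default pickle_protocol. It shows that pickle keeps only a reference to the class plus the instance state.

```python
import copyreg


class Module:
    def __init__(self) -> None:
        self._one = 1


# __reduce_ex__(2) returns what pickle protocol 2 serializes for the object:
# (callable, args, state, listitems, dictitems)
func, args, state, *_ = Module().__reduce_ex__(2)

print(func is copyreg.__newobj__)  # True: rebuilt with cls.__new__(cls), not __init__
print(args == (Module,))           # True: only a reference to the class
print(state)                       # {'_one': 1}: the instance __dict__
```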

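When module.pth is loaded from another script (load.py), it may raise an AttributeError (on PyTorch 2.6 and later, torch.load defaults to weights_only=True and rejects custom objects before this point, so pass weights_only=False to reproduce it):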

import torch
torch.load('module.pth')
Traceback (most recent call last):
  File "/Users/bytedance/Developer/todd/load.py", line 3, in <module>
    torch.load('module.pth')
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 809, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1172, in _load
    result = unpickler.load()
             ^^^^^^^^^^^^^^^^
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1165, in find_class
    return super().find_class(mod_name, name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: Can't get attribute 'Module' on <module '__main__' from '/Users/bytedance/Developer/todd/load.py'>

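This is because torch.save is built on pickle, and when pickle saves an instance of a custom type it does not store the type definition; it records only a reference to the class (its module and name, here __main__.Module) together with the instance state. It is up to the user to make the custom type accessible when torch.load runs, so that the saved object can be reconstructed. In other words, once Module is brought into the current namespace, module.pth loads normally: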

import torch
from save import Module
torch.load('module.pth')
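The same failure, and the fix, can be reproduced with plain pickle. In this sketch the class is attached to a throwaway module object named demo_models (a hypothetical name registered in sys.modules only for the demo), so pickle records it as demo_models.Module. Removing that attribute triggers the AttributeError, and putting a class back under the recorded name lets loading succeed.

```python
import pickle
import sys
import types

# Hypothetical stand-in for the module that defined the class at save time
demo = types.ModuleType('demo_models')
sys.modules['demo_models'] = demo


class Module:
    def __init__(self) -> None:
        self._one = 1


# Make pickle record the class as demo_models.Module
Module.__module__ = 'demo_models'
Module.__qualname__ = 'Module'
demo.Module = Module
data = pickle.dumps(Module())

# 1) The module is importable but the class is gone -> AttributeError
del demo.Module
try:
    pickle.loads(data)
except AttributeError as exc:
    print('AttributeError:', exc)

# 2) Make the class reachable under the recorded name again -> loads fine
demo.Module = Module
restored = pickle.loads(data)
print(restored._one)  # 1
```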

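Sometimes, though, the custom type is simply not accessible, and we do not need the saved object back; we only want to know what data it stores. In that case we can define a stand-in class with the same name that captures the saved state through __setstate__: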

import torch
class Module:
    # pickle rebuilds the object with __new__ and never calls __init__, so the
    # fallback for "__setstate__ was not called" must be a class attribute
    _state = None
    def __setstate__(self, state):
        # pickle calls __setstate__ only when the saved state is not empty
        self._state = state
module = torch.load('module.pth')
print(module._state)  # {'_one': 1}
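The stand-in trick relies on two pickle rules: objects are rebuilt with __new__ so __init__ is skipped, and __setstate__ receives the saved state only when that state is non-empty. The sketch below checks both with plain pickle; demo_inspect, Original and Probe are hypothetical names, and the saving class is given the qualified name Module to mirror the example.

```python
import pickle
import sys
import types

# Hypothetical module name; the original class lives there at save time
mod = types.ModuleType('demo_inspect')
sys.modules['demo_inspect'] = mod


class Original:
    def __init__(self, fill: bool) -> None:
        if fill:
            self._one = 1


Original.__module__ = 'demo_inspect'
Original.__qualname__ = 'Module'
mod.Module = Original
full = pickle.dumps(Original(fill=True))    # state is {'_one': 1}
empty = pickle.dumps(Original(fill=False))  # empty __dict__, so no state saved


class Probe:
    # Class attributes act as fallbacks, since unpickling skips __init__
    _state = None
    init_called = False

    def __init__(self) -> None:
        self.init_called = True

    def __setstate__(self, state):
        self._state = state


# Swap in the probe under the recorded name before loading
mod.Module = Probe
a = pickle.loads(full)
b = pickle.loads(empty)
print(a._state, a.init_called)  # {'_one': 1} False
print(b._state, b.init_called)  # None False
```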

Things get harder when the custom type is imported from another module, for example:

# module.py
import torch.nn as nn
class Module(nn.Module):
    def __init__(self) -> None:
        self._one = 1

# save.py
import torch
from module import Module
torch.save(Module(), 'module.pth')

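Now the pickle refers to the class as module.Module, so torch.load first tries to import the module named module and raises an error if it does not exist: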

Traceback (most recent call last):
  File "/Users/bytedance/Developer/todd/load.py", line 13, in <module>
    module = torch.load('module.pth')
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 809, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1172, in _load
    result = unpickler.load()
             ^^^^^^^^^^^^^^^^
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1165, in find_class
    return super().find_class(mod_name, name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ModuleNotFoundError: No module named 'module'
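With plain pickle the missing-module case looks the same: the class is recorded by module and name, and unpickling imports that module first. Here demo_missing is a hypothetical module that exists only in sys.modules while saving and is removed before loading, which simulates a machine where the module file is absent.

```python
import pickle
import sys
import types

# Hypothetical module name, standing in for module.py from the example
mod = types.ModuleType('demo_missing')
sys.modules['demo_missing'] = mod


class Module:
    def __init__(self) -> None:
        self._one = 1


Module.__module__ = 'demo_missing'
Module.__qualname__ = 'Module'
mod.Module = Module
data = pickle.dumps(Module())

# Simulate loading where demo_missing cannot be imported
del sys.modules['demo_missing']
try:
    pickle.loads(data)
except ModuleNotFoundError as exc:
    print(type(exc).__name__, exc)  # ModuleNotFoundError No module named 'demo_missing'
```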

We can register a mock in place of that module:

import sys
from unittest.mock import Mock
import torch
sys.modules['module'] = Mock()
torch.load('module.pth')
Traceback (most recent call last):
  File "/Users/bytedance/Developer/todd/load.py", line 14, in <module>
    module = torch.load('module.pth')
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 809, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1172, in _load
    result = unpickler.load()
             ^^^^^^^^^^^^^^^^
_pickle.UnpicklingError: NEWOBJ class argument must be a type, not Mock
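The NEWOBJ error can also be reproduced without torch. In this sketch demo_mocked is a hypothetical module name; once it is replaced by a Mock, looking up Module on it returns yet another Mock instead of a class, and the unpickler rejects it.

```python
import pickle
import sys
import types
from unittest.mock import Mock

# Hypothetical module name; the real class exists only while saving
mod = types.ModuleType('demo_mocked')
sys.modules['demo_mocked'] = mod


class Module:
    def __init__(self) -> None:
        self._one = 1


Module.__module__ = 'demo_mocked'
Module.__qualname__ = 'Module'
mod.Module = Module
data = pickle.dumps(Module())

# Replace the module with a Mock, as in the torch example
sys.modules['demo_mocked'] = Mock()
attr = sys.modules['demo_mocked'].Module
print(type(attr).__name__, isinstance(attr, type))  # Mock False

try:
    pickle.loads(data)
except pickle.UnpicklingError as exc:
    print(exc)
```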

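This happens because Mock creates attributes recursively: sys.modules['module'].Module is just another Mock object rather than a class, and pickle's NEWOBJ opcode requires a type. We can fix this by assigning a real stand-in class to that attribute by hand: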

import sys
from unittest.mock import Mock
import torch
class Module:
    # pickle does not call __init__, so the fallback is a class attribute
    _state = None
    def __setstate__(self, state):
        self._state = state
sys.modules['module'] = Mock()
# replace the auto-created Mock attribute with a real class
sys.modules['module'].Module = Module
module = torch.load('module.pth')
print(module._state)  # {'_one': 1}
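A different technique from the Mock patch, sketched here at the pickle level, is to subclass pickle.Unpickler and override its find_class method so that any class that cannot be imported is replaced by a placeholder that keeps the saved state. Placeholder, LenientUnpickler, lenient_load and demo_gone are hypothetical names. The sketch works on plain pickle bytes; a .pth file written by torch.save is a zip archive whose tensor storages are handled by torch's own loader, so this does not replace torch.load as-is.

```python
import io
import pickle
import sys
import types


class Placeholder:
    """Stands in for any class that cannot be imported."""
    _state = None  # class attribute: unpickling skips __init__

    def __setstate__(self, state):
        self._state = state


class LenientUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        try:
            return super().find_class(module, name)
        except (ImportError, AttributeError):
            # A distinct subclass per missing class, remembering where it came from
            return type(name, (Placeholder,), {'_origin': f'{module}.{name}'})


def lenient_load(data: bytes):
    return LenientUnpickler(io.BytesIO(data)).load()


# Demo: pickle an object whose module is then removed (hypothetical name)
mod = types.ModuleType('demo_gone')
sys.modules['demo_gone'] = mod


class Module:
    def __init__(self) -> None:
        self._one = 1


Module.__module__ = 'demo_gone'
Module.__qualname__ = 'Module'
mod.Module = Module
data = pickle.dumps(Module())
del sys.modules['demo_gone']

obj = lenient_load(data)
print(obj._origin, obj._state)  # demo_gone.Module {'_one': 1}
```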
Author: LutingWang

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值