Python 3.11.3 (main, Apr 7 2023, 19:25:52) [Clang 14.0.0 (clang-1400.0.29.202)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
正常情况下,我们会使用 torch.save 保存模型的 state_dict。但我们也可以用 torch.save 保存一个自定义类型的对象,例如
import torch
import torch.nn as nn
class Module(nn.Module):
    """Minimal custom module used to demonstrate saving a whole object.

    NOTE(review): ``super().__init__()`` is not called here. The state
    printed later in this document is just ``{'_one': 1}``, which appears
    to rely on that omission -- confirm before adding the call.
    """

    def __init__(self) -> None:
        # The only piece of state this example carries.
        self._one = 1
# Save the whole Module instance (rather than a state_dict) to module.pth.
torch.save(Module(), 'module.pth')
在读取 module.pth
时,可能会遇到 AttributeError
import torch
# Raises the AttributeError shown below: this script defines no `Module`.
# NOTE(review): torch.load unpickles the file, so only load trusted files.
torch.load('module.pth')
Traceback (most recent call last):
File "/Users/bytedance/Developer/todd/load.py", line 3, in <module>
torch.load('module.pth')
File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 809, in load
return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1172, in _load
result = unpickler.load()
^^^^^^^^^^^^^^^^
File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1165, in find_class
return super().find_class(mod_name, name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: Can't get attribute 'Module' on <module '__main__' from '/Users/bytedance/Developer/todd/load.py'>
这是因为 torch.save 底层通过 pickle 实现,而 pickle 在保存自定义类型对象时不会保存其类型定义,只会记录类型所在的模块名和类名。用户需要保证 torch.load 时,自定义类型可以按该模块名和类名访问到,以便构造被保存的对象。也就是说,如果我们将 Module 引入到当前命名空间,就可以正常加载 module.pth 了
import torch
# Bring the original class into this namespace so the loader can find it.
# NOTE(review): presumably `save` is the saving script shown earlier; if so,
# importing it also re-runs its top-level torch.save call -- confirm.
from save import Module
torch.load('module.pth')
但是有些情况下,我们无法访问某些自定义类型,也不希望恢复被保存的对象,只想知道被保存的对象存储了哪些数据,可以用下面的方法
import torch
class Module:
    """Stand-in for the pickled class that records the raw pickled state.

    Unpickling creates the instance through ``__new__`` and does not run
    ``__init__``, so the fallback value for ``_state`` has to live on the
    class. Otherwise ``_state`` would be missing whenever the pickle
    carries no state and ``__setstate__`` is never invoked.
    """

    # in case __setstate__ is not called
    _state = None

    def __init__(self) -> None:
        # Kept so that direct construction still sets an instance attribute.
        self._state = None

    def __setstate__(self, state):
        # whenever state is not empty, __setstate__ is called
        self._state = state
# The loaded object is built from the stand-in Module class defined above.
module = torch.load('module.pth')
# Print the raw state captured by __setstate__.
print(module._state)
{'_one': 1}
但是如果自定义类型是从其他位置 import
得到的,例如
# module.py
import torch.nn as nn
class Module(nn.Module):
    """Minimal custom module, this time defined in its own file (module.py).

    NOTE(review): ``super().__init__()`` is not called here. The state
    printed elsewhere in this document is just ``{'_one': 1}``, which
    appears to rely on that omission -- confirm before adding the call.
    """

    def __init__(self) -> None:
        # The only piece of state this example carries.
        self._one = 1
# save.py
import torch
from module import Module
# Module is imported from module.py here, so loading the file later needs a
# module named 'module' to be importable (see the ModuleNotFoundError below).
torch.save(Module(), 'module.pth')
torch.load
会先尝试 import
相应的模块,如果不存在就会报错
Traceback (most recent call last):
File "/Users/bytedance/Developer/todd/load.py", line 13, in <module>
module = torch.load('module.pth')
^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 809, in load
return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1172, in _load
result = unpickler.load()
^^^^^^^^^^^^^^^^
File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1165, in find_class
return super().find_class(mod_name, name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ModuleNotFoundError: No module named 'module'
我们可以 mock
相应模块
import sys
from unittest.mock import Mock
import torch
# Register a Mock under the name 'module' so that importing it succeeds.
sys.modules['module'] = Mock()
# Still fails (see the traceback below): the class handed to pickle's NEWOBJ
# is a Mock, not a type.
torch.load('module.pth')
Traceback (most recent call last):
File "/Users/bytedance/Developer/todd/load.py", line 14, in <module>
module = torch.load('module.pth')
^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 809, in load
return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/bytedance/.local/share/virtualenvs/todd-ARrcnwyq/lib/python3.11/site-packages/torch/serialization.py", line 1172, in _load
result = unpickler.load()
^^^^^^^^^^^^^^^^
_pickle.UnpicklingError: NEWOBJ class argument must be a type, not Mock
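出现这个问题,是因为 Mock 会在访问任意属性时自动(递归地)创建新的 Mock 对象:sys.modules['module'].Module 得到的是一个 Mock 实例而不是类型,而 pickle 的 NEWOBJ 要求其参数必须是类型。我们可以手动将 Module 属性设置为一个真正的类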
import sys
from unittest.mock import Mock
import torch
class Module:
    """Stand-in for the pickled class that records the raw pickled state.

    Unpickling creates the instance through ``__new__`` and does not run
    ``__init__``, so the fallback value for ``_state`` has to live on the
    class. Otherwise ``_state`` would be missing whenever the pickle
    carries no state and ``__setstate__`` is never invoked.
    """

    # Fallback used when __setstate__ is never called.
    _state = None

    def __init__(self) -> None:
        # Kept so that direct construction still sets an instance attribute.
        self._state = None

    def __setstate__(self, state):
        # Keep the pickled state as-is instead of restoring attributes.
        self._state = state
# Build the fake 'module' module first, exposing a real class under the
# pickled name so the unpickler receives a type rather than a Mock.
fake_module = Mock()
fake_module.Module = Module
sys.modules['module'] = fake_module
module = torch.load('module.pth')
print(module._state)
{'_one': 1}