The error is as follows:

AttributeError                            Traceback (most recent call last)
Cell In[10], line 1
----> 1 model = torch.load('fruit30_pytorch_light.pth')
      2 model = model.eval().to(device)

File E:\Anaconda\envs\pytorch\lib\site-packages\torch\serialization.py:712, in load(f, map_location, pickle_module, **pickle_load_args)
    710 opened_file.seek(orig_position)
    711 return torch.jit.load(opened_file)
--> 712 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
    713 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)

File E:\Anaconda\envs\pytorch\lib\site-packages\torch\serialization.py:1049, in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
   1047 unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
   1048 unpickler.persistent_load = persistent_load
-> 1049 result = unpickler.load()
   1051 torch._utils._validate_loaded_sparse_tensors()
   1053 return result

File E:\Anaconda\envs\pytorch\lib\site-packages\torch\serialization.py:1042, in _load.<locals>.UnpicklerWrapper.find_class(self, mod_name, name)
   1040 pass
   1041 mod_name = load_module_mapping.get(mod_name, mod_name)
-> 1042 return super().find_class(mod_name, name)

AttributeError: Can't get attribute 'Net' on <module '__main__'>
This happened after I trained a model built from my own Net class with PyTorch and saved it like this:
torch.save(model, 'checkpoints/fruit30_pytorch_light.pth')
Then I loaded it again in the prediction (inference) stage:
model = torch.load('fruit30_pytorch_light.pth')
model = model.eval().to(device)
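This error means that the new inference script cannot find the definition of your custom model class. torch.save(model, ...) pickles the whole model object, and pickle records only a reference to the object's class (here Net in the __main__ module), not the class code itself. When torch.load unpickles the file in a different script, Net must therefore already be defined there. The fix is to copy the entire class you wrote into the inference script, above the loading code shown above!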
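The mechanism can be reproduced without PyTorch, because torch.save(model, path) serializes the model with Python's pickle module. The minimal sketch below uses the standard-library pickle directly. It defines a hypothetical stand-in Net class (not the real model) and shows three things: the saved bytes contain only a reference to __main__.Net, loading fails while that name is missing, and loading succeeds once the class is defined again. It assumes the snippet is run as a script, so that the class lives in __main__.

```python
import pickle


class Net:
    """Hypothetical stand-in for the user's model class (not the real Net)."""

    def __init__(self, num_classes=30):
        self.num_classes = num_classes


# "Training script": pickle the whole object, which is what torch.save(model, path) does.
payload = pickle.dumps(Net())

# The bytes hold only a reference to the class (module "__main__" + name "Net"), not its code.
print(b"__main__" in payload, b"Net" in payload)

# "Inference script" without the class: simulate it by removing Net from this module.
saved_cls = globals().pop("Net")
load_failed = False
try:
    pickle.loads(payload)
except (AttributeError, pickle.UnpicklingError) as err:
    load_failed = True
    print("load failed:", err)

# Fix: define the same class under the same name again, then load.
globals()["Net"] = saved_cls
model = pickle.loads(payload)
print(type(model).__name__, model.num_classes)
```

In the real inference script, the equivalent of restoring Net is placing the same Net class definition above the torch.load call.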