1 torch.hub.load 加载网络模型错误
通过网络使用torch.hub.load加载模型代码如下:
self.model = torch.hub.load("facebookresearch/dinov2", 'dinov2_vits14', source='github').to(self.device)
运行网上的项目,经常会卡住或者超时,原因是 torch.hub.load 默认会去网上找模型,而github经常是不可访问的(需要走代理),从而导致网络异常,错误如下:
Traceback (most recent call last):
File "/opt/pa_retrieve/preprocessor/remove_redundant_image.py", line 4, in <module>
from model.dinov2_embeding_small import dinov2_embeding_small
File "/opt/pa_retrieve/model/dinov2_embeding_small.py", line 42, in <module>
dinov2_embeding_small = Dinov2EmbedingSmall()
File "/opt/pa_retrieve/model/dinov2_embeding_small.py", line 20, in __init__
self.model = torch.hub.load("facebookresearch/dinov2", 'dinov2_vits14', source='github').to(self.device)
File "/root/anaconda3/envs/pa/lib/python3.9/site-packages/torch/hub.py", line 555, in load
repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, trust_repo, "load",
File "/root/anaconda3/envs/pa/lib/python3.9/site-packages/torch/hub.py", line 199, in _get_cache_or_reload
repo_owner, repo_name, ref = _parse_repo_info(github)
File "/root/anaconda3/envs/pa/lib/python3.9/site-packages/torch/hub.py", line 142, in _parse_repo_info
with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"):
File "/root/anaconda3/envs/pa/lib/python3.9/urllib/request.py", line 214, in urlopen
return opener.open(url, data, timeout)
File "/root/anaconda3/envs/pa/lib/python3.9/urllib/request.py", line 517, in open
response = self._open(req, data)
File "/root/anaconda3/envs/pa/lib/python3.9/urllib/request.py", line 534, in _open
result = self._call_chain(self.handle_open, protocol, protocol +
File "/root/anaconda3/envs/pa/lib/python3.9/urllib/request.py", line 494, in _call_chain
result = func(*args)
File "/root/anaconda3/envs/pa/lib/python3.9/urllib/request.py", line 1389, in https_open
return self.do_open(http.client.HTTPSConnection, req,
File "/root/anaconda3/envs/pa/lib/python3.9/urllib/request.py", line 1350, in do_open
r = h.getresponse()
File "/root/anaconda3/envs/pa/lib/python3.9/http/client.py", line 1377, in getresponse
response.begin()
File "/root/anaconda3/envs/pa/lib/python3.9/http/client.py", line 320, in begin
version, status, reason = self._read_status()
File "/root/anaconda3/envs/pa/lib/python3.9/http/client.py", line 289, in _read_status
raise RemoteDisconnected("Remote end closed connection without"
http.client.RemoteDisconnected: Remote end closed connection without response
2 torch.hub.load加载本地模型
通过代理下载模型和代码,模型存放在如下目录下:
/root/.cache/torch/hub/checkpoints/
工程代码存放在如下目录下:
/root/.cache/torch/hub/facebookresearch_dinov2_main/
更改模型加载的代码为本地加载,代码如下:
self.model = torch.hub.load('/root/.cache/torch/hub/facebookresearch_dinov2_main', 'dinov2_vits14', trust_repo=True, source='local').to(self.device)
再次运行程序,模型加载成功。
3 torch.hub.load
torch.hub.load
是一个强大的工具,为PyTorch用户提供了快速访问和使用预训练模型的能力。它简化了模型加载过程,使得用户可以更加专注于模型的应用和开发,而不是从头开始训练模型。通过PyTorch Hub,PyTorch社区正在建立一个共享知识和资源的平台,是 PyTorch 生态系统的一个重要组成部分,特别适合于那些希望快速实现机器学习和深度学习原型的研究人员和开发者。
3.1 基本概念
-
PyTorch Hub:PyTorch Hub 是一个预训练模型的集合库,旨在促进模型的发现、共享和重用。通过提供一个集中的地方来存储经过验证的模型,PyTorch Hub 使得访问和共享高质量的模型变得简单。
-
torch.hub.load:这是一个用于加载 PyTorch Hub 中模型的函数。它通过简单的API调用,使用户能够轻松地访问和使用这些预训练模型。
3.2 功能和优势
-
简化模型访问:
torch.hub.load
使得访问和加载预训练模型变得非常简单,通常只需要一行代码。 -
广泛的模型支持:PyTorch Hub 支持各种模型,包括图像分类、文本处理、生成模型等。
-
版本控制和可复现性:PyTorch Hub 提供模型的版本控制功能,确保结果的可复现性。
-
社区贡献:开发者和研究人员可以贡献自己的模型到 PyTorch Hub,促进知识共享和合作。
3.3 使用场景
-
快速原型设计:研究人员和开发者可以快速加载预训练模型来验证想法和构建原型。
-
性能基准测试:通过使用标准的预训练模型,用户可以对自己的模型进行性能对比。
-
教学和学习:教师和学生可以使用
torch.hub.load
来访问先进的模型,加深对深度学习的理解。
3.4 如何使用
使用torch.hub.load
加载模型的一般流程如下:
-
导入 PyTorch Hub:首先需要导入 torch.hub 库。
import torch.hub
-
加载模型:使用
torch.hub.load
函数加载模型。你需要指定仓库的 GitHub 地址、模型的名称以及想要加载的模型版本。
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True)
在这个例子中,我们从 PyTorch 的 vision 仓库加载了预训练的 ResNet18 模型。
-
使用模型:一旦模型被加载,它就可以像普通的 PyTorch 模型一样被使用。
output = model(input_data)
这里input_data
是你想要处理的数据。
3.5 注意事项
-
依赖项:加载某些模型可能需要安装额外的依赖项。
-
网络连接:由于模型是从互联网下载的,因此需要稳定的网络连接。
-
存储空间:预训练模型可能需要较大的磁盘空间。
-
计算资源:一些模型特别是大型的模型可能需要较强的计算资源,如GPU加速。