解决torch.hub.load加载网络模型异常

文章讲述了在使用torch.hub.load加载PyTorch模型时遇到的问题,如网络连接不稳定导致的下载失败,以及如何通过本地加载解决。同时介绍了torch.hub.load的功能和注意事项,包括依赖性、网络需求和存储资源管理。
摘要由CSDN通过智能技术生成

1 torch.hub.load 加载网络模型错误

通过网络使用torch.hub.load加载模型代码如下:

self.model = torch.hub.load("facebookresearch/dinov2", 'dinov2_vits14', source='github').to(self.device)

运行网上的项目,经常会卡住或者超时,原因是 torch.hub.load 默认会去网上找模型,而github经常是不可访问的(需要走代理),从而导致网络异常,错误如下:

Traceback (most recent call last):
  File "/opt/pa_retrieve/preprocessor/remove_redundant_image.py", line 4, in <module>
    from model.dinov2_embeding_small import dinov2_embeding_small
  File "/opt/pa_retrieve/model/dinov2_embeding_small.py", line 42, in <module>
    dinov2_embeding_small = Dinov2EmbedingSmall()
  File "/opt/pa_retrieve/model/dinov2_embeding_small.py", line 20, in __init__
    self.model = torch.hub.load("facebookresearch/dinov2", 'dinov2_vits14', source='github').to(self.device)
  File "/root/anaconda3/envs/pa/lib/python3.9/site-packages/torch/hub.py", line 555, in load
    repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, trust_repo, "load",
  File "/root/anaconda3/envs/pa/lib/python3.9/site-packages/torch/hub.py", line 199, in _get_cache_or_reload
    repo_owner, repo_name, ref = _parse_repo_info(github)
  File "/root/anaconda3/envs/pa/lib/python3.9/site-packages/torch/hub.py", line 142, in _parse_repo_info
    with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"):
  File "/root/anaconda3/envs/pa/lib/python3.9/urllib/request.py", line 214, in urlopen
    return opener.open(url, data, timeout)
  File "/root/anaconda3/envs/pa/lib/python3.9/urllib/request.py", line 517, in open
    response = self._open(req, data)
  File "/root/anaconda3/envs/pa/lib/python3.9/urllib/request.py", line 534, in _open
    result = self._call_chain(self.handle_open, protocol, protocol +
  File "/root/anaconda3/envs/pa/lib/python3.9/urllib/request.py", line 494, in _call_chain
    result = func(*args)
  File "/root/anaconda3/envs/pa/lib/python3.9/urllib/request.py", line 1389, in https_open
    return self.do_open(http.client.HTTPSConnection, req,
  File "/root/anaconda3/envs/pa/lib/python3.9/urllib/request.py", line 1350, in do_open
    r = h.getresponse()
  File "/root/anaconda3/envs/pa/lib/python3.9/http/client.py", line 1377, in getresponse
    response.begin()
  File "/root/anaconda3/envs/pa/lib/python3.9/http/client.py", line 320, in begin
    version, status, reason = self._read_status()
  File "/root/anaconda3/envs/pa/lib/python3.9/http/client.py", line 289, in _read_status
    raise RemoteDisconnected("Remote end closed connection without"
http.client.RemoteDisconnected: Remote end closed connection without response

2 torch.hub.load加载本地模型

通过代理下载模型和代码,模型存放在如下目录下:

/root/.cache/torch/hub/checkpoints/

工程代码存放在如下目录下:

/root/.cache/torch/hub/facebookresearch_dinov2_main/

更改模型加载的代码为本地加载,代码如下:

self.model = torch.hub.load('/root/.cache/torch/hub/facebookresearch_dinov2_main', 'dinov2_vits14', trust_repo=True, source='local').to(self.device)

再次运行程序,模型加载成功。

3 torch.hub.load

torch.hub.load是一个强大的工具,为PyTorch用户提供了快速访问和使用预训练模型的能力。它简化了模型加载过程,使得用户可以更加专注于模型的应用和开发,而不是从头开始训练模型。通过PyTorch Hub,PyTorch社区正在建立一个共享知识和资源的平台,是 PyTorch 生态系统的一个重要组成部分,特别适合于那些希望快速实现机器学习和深度学习原型的研究人员和开发者。

3.1 基本概念

  • PyTorch Hub:PyTorch Hub 是一个预训练模型的集合库,旨在促进模型的发现、共享和重用。通过提供一个集中的地方来存储经过验证的模型,PyTorch Hub 使得访问和共享高质量的模型变得简单。

  • torch.hub.load:这是一个用于加载 PyTorch Hub 中模型的函数。它通过简单的API调用,使用户能够轻松地访问和使用这些预训练模型。

3.2 功能和优势

  • 简化模型访问torch.hub.load 使得访问和加载预训练模型变得非常简单,通常只需要一行代码。

  • 广泛的模型支持:PyTorch Hub 支持各种模型,包括图像分类、文本处理、生成模型等。

  • 版本控制和可复现性:PyTorch Hub 提供模型的版本控制功能,确保结果的可复现性。

  • 社区贡献:开发者和研究人员可以贡献自己的模型到 PyTorch Hub,促进知识共享和合作。

3.3 使用场景

  • 快速原型设计:研究人员和开发者可以快速加载预训练模型来验证想法和构建原型。

  • 性能基准测试:通过使用标准的预训练模型,用户可以对自己的模型进行性能对比。

  • 教学和学习:教师和学生可以使用torch.hub.load来访问先进的模型,加深对深度学习的理解。

3.4 如何使用

使用torch.hub.load加载模型的一般流程如下:

  • 导入 PyTorch Hub:首先需要导入 torch.hub 库。

import torch.hub 
  • 加载模型:使用torch.hub.load函数加载模型。你需要指定仓库的 GitHub 地址、模型的名称以及想要加载的模型版本。

model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True) 

在这个例子中,我们从 PyTorch 的 vision 仓库加载了预训练的 ResNet18 模型。

  • 使用模型:一旦模型被加载,它就可以像普通的 PyTorch 模型一样被使用。

output = model(input_data) 

这里input_data是你想要处理的数据。

3.5 注意事项

  • 依赖项:加载某些模型可能需要安装额外的依赖项。

  • 网络连接:由于模型是从互联网下载的,因此需要稳定的网络连接。

  • 存储空间:预训练模型可能需要较大的磁盘空间。

  • 计算资源:一些模型特别是大型的模型可能需要较强的计算资源,如GPU加速。

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

智慧医疗

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值