torch.hub.load()把联网加载修改为本地加载模型

1. torch.load()

torch.load函数用于从磁盘加载已保存的模型或张量,以便进行后续的操作。这也是我们常用的一种导入预训练模型的方式,可以使用以下方式调用该函数:

model = torch.load('model.pth')

其中,model.pth就是我们存放模型的路径。

2.  torch.hub.load()

最近在复现某一个关于yolo的项目中遇到了这个方法,从该方法的hub可以看出,它在每次加载模型时都要联网进行加载。比如:

model = torch.hub.load(
            "ultralytics/yolov5",
            "custom",
            path=f"{local_model_path}/{model_name}",
            device=device,
            force_reload=[True if "refresh_yolov5" in opt else False][0],
            _verbose=True,
        )

其中custom表示自定义的模型,path是本地权重文件的路径,而"ultralytics/yolov5"表示该load方法每次加载模型时,都会访问到GitHub - ultralytics/yolov5: YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite这个网址。不过有些时候国内加载github没有那么稳定,就会导致这个load方法经常报“远程连接失败”的错误。

3. 如何把torch.hub.load()改为每次从本地加载?

1)将所要加载的存储库直接搬到项目中来

比如我需要的存储库在GitHub - ultralytics/yolov5: YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite,就可以直接访问该github网站把整个包克隆下来,放到项目中来(我放在了根目录下)。

2)修改hub.load代码

修改代码如下:

model = torch.hub.load(
            "./ultralytics_yolov5_master",
            "custom",
            path=f"{local_model_path}/{model_name}",
            device=device,
            source='local',
            force_reload=[True if "refresh_yolov5" in opt else False][0],
            _verbose=True,
        )

主要是两处发生了变化,一个是增加了参数source='local',指明我们是要从本地加载而不是联网加载(因为默认是source='github'),另外就是第一个参数中的路径(即加载路径)发生了变化,因为我们在第一步中已经将存储库拷贝到本地项目包的根目录下了。

到这里,之后再运行项目就会默认从本地加载啦。(>_<  联网加载真的太折磨人了)

---------------------------------------------------------------------------------------------------------------------------------

新人发帖,多多关照 ~

  • 10
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值