使用本地文件创建resnet50模型

使用本地文件创建resnet50模型

1. 问题分析

最近用到目标检测的模型DETR,但是在创建模型的时候却遇到模型无法创建的问题。本文记录一下解决该问题的过程。

检查原因发现是在创建模型的过程中,需要联网下载。

即便是我将facebook/detr-resnet-50的所有文件下载到本地,然后在from_pretrained时候指定本地的路径,仍然遇到了连接hf下载模型的问题。由于一些众所周知的原因,hf无法直接访问了,这就导致下载遇到了点问题。

需要下载的文件是resnet50的backbone,是没有被包含在detr的模型文件中的。
transformers中的modeling_detr.py中并没有给resnet的本地文件留入参,这就带来了很多不便。即便如此,我们还是可以在modeling_detr.py手动创建backbone,以避免联网下载。

2. 问题解决

首先还是需要先下载resnet50的权重(需科学上网):
https://huggingface.co/timm/resnet50.a1_in1k/tree/main

将这些文件放在目录(记作path_a)中。
然后修改transfomers模块中的源码transformers/models/detr/modeling_detr.py
大约340行:

    def __init__(self, config):
        super().__init__()

        self.config = config

        if config.use_timm_backbone:
            requires_backends(self, ["timm"])
            kwargs = {}
            if config.dilation:
                kwargs["output_stride"] = 16
            backbone = create_model(
                config.backbone,
                pretrained=config.use_pretrained_backbone,
                features_only=True,
                out_indices=(1, 2, 3, 4),
                in_chans=config.num_channels,
                **kwargs,
            )
        else:
            backbone = AutoBackbone.from_config(config.backbone_config)

修改为:

    def __init__(self, config):
        super().__init__()

        self.config = config
        # 从指定路径直接创建backbone
		import timm
		backbone = timm.create_model(
   			'resnet50',
    		pretrained=True,
			pretrained_cfg_overlay=dict(file='path_a/pytorch_model.bin'),  # 刚才下载模型保存的路径
    		features_only=True,
    		out_indices=(1, 2, 3, 4),
    		in_chans=3,
)
		# 原来的部分全都注释掉
		'''
        if config.use_timm_backbone:
            requires_backends(self, ["timm"])
            kwargs = {}
            if config.dilation:
                kwargs["output_stride"] = 16
            backbone = create_model(
                config.backbone,
                pretrained=config.use_pretrained_backbone,
                features_only=True,
                out_indices=(1, 2, 3, 4),
                in_chans=config.num_channels,
                **kwargs,
            )
        else:
            backbone = AutoBackbone.from_config(config.backbone_config)
        '''

修改之后再创建detr模型,就不会报错了:

processor = DetrImageProcessor.from_pretrained("your_path_to/detr-resnet-50/", revision="no_timm")
model = DetrForObjectDetection.from_pretrained("your_path_to/detr-resnet-50/", revision="no_timm")

其他用到了timm模块的hf模型,也可以用类似的方法解决联网下载的问题。

  • 3
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值