【pytorch】Timm库从本地权重文件初始化预训练模型

由于github连接不稳定,导致有时候Timm库自动从URL下载模型失败。这时候如果要使用预训练模型,就需要提前下载到本地,但是timm没有直接的对应接口,需要做一些调整。

读取整个模型的情况(feature_only=False)

这种情况比较简单,指定create_model的checkpoint_path为本地预训练模型。注意:pretrained=True会自动尝试下载,如果要加载本地模型,需要设置成false。以efficientnet为例:

import timm

model = timm.create_model(
                                "tf_efficientnetv2_s_in21ft1k",
                                pretrained=False,
                                checkpoint_path=#本地权重的路径
                            )

读取特征提取器的情况(features_only=True)

若设置create_model的features_only=True,则会从原模型剪枝生成一个feature模型。而checkpoints_path选项是构建完模型后,再去读取预训练模型权重,这会导致目标权重和模型不匹配。
因此,如果想要设置features_only=True,就必须设置pretrained=True,通过pretrained_cfg这个参数自定义内部模型加载,将url修改为本地文件:

import timm
from timm.models.efficientnet import _cfg

config = _cfg(url='', file='tf_efficientnetv2_s_21ft1k-d7dafa41.pth') #file为本地文件路径

model = timm.create_model(
                                "tf_efficientnetv2_s_in21ft1k",
                                pretrained=True,
                                features_only=True,
                                pretrained_cfg=config
                            )

看到没人发,就记录一下。

  • 32
    点赞
  • 41
    收藏
    觉得还不错? 一键收藏
  • 8
    评论
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值