由于github连接不稳定,导致有时候Timm库自动从URL下载模型失败。这时候如果要使用预训练模型,就需要提前下载到本地,但是timm没有直接的对应接口,需要做一些调整。
读取整个模型的情况(feature_only=False)
这种情况比较简单,指定create_model的checkpoint_path为本地预训练模型。注意:pretrained=True会自动尝试下载,如果要加载本地模型,需要设置成false。以efficientnet为例:
import timm
model = timm.create_model(
"tf_efficientnetv2_s_in21ft1k",
pretrained=False,
checkpoint_path=#本地权重的路径
)
读取特征提取器的情况(features_only=True)
若设置create_model的features_only=True,则会从原模型剪枝生成一个feature模型。而checkpoints_path选项是构建完模型后,再去读取预训练模型权重,这会导致目标权重和模型不匹配。
因此,如果想要设置features_only=True,就必须设置pretrained=True,通过pretrained_cfg这个参数自定义内部模型加载,将url修改为本地文件:
import timm
from timm.models.efficientnet import _cfg
config = _cfg(url='', file='tf_efficientnetv2_s_21ft1k-d7dafa41.pth') #file为本地文件路径
model = timm.create_model(
"tf_efficientnetv2_s_in21ft1k",
pretrained=True,
features_only=True,
pretrained_cfg=config
)
看到没人发,就记录一下。