最近在使用timm,总结了一些方法,可能会对读者有帮助。
1. 首先是安装timm包
pip intall timm
2. 通过下面代码展示timm具有的模型名称,根据输出模型的名称选择自己需要的模型:
model_list = timm.list_models()
print(model_list)
3. 加载模型。
model = timm.create_model('convnext_base', pretrained=True, num_classes=2)
但是因为timm的升级,导致了国内无法连接到Hugging Face网站,没有办法使用手动下载预训练模型。总是出现这样的错误:
huggingface_hub.utils._errors.LocalEntryNotFoundError: Connection error, and we cannot find the requested files in the disk cache.
4. 针对上述无法连接的问题,采用这样的解决方案,首先
model = timm.create_model('convnext_base', num_classes=2, global_pool='')#pretrained=True,
print(model.default_cfg)#查看模型cfg
得到
{'url': 'https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth',
科学上网,手动下载这个预训练模型。
pre_path = ' 下载预训练模型权重的文件路径'
model = timm.create_model('convnext_base', pretrained=True, num_classes=2, pretrained_cfg_overlay=dict(file=pre_path ))
这样就解决了,连接失败的问题,就可以正常使用手动下载的预训练模型了。