timm加载模型create_model使用本地预训练模型
1.常规方式,从https://huggingface.co/上下载
注:国内使用你可能会遇到huggingface.co连接失败。
model = timm.create_model("resnet18", pretrained = True, num_classses = 2)
1-1. timm库中create_model函数的用法
1.最简单的用法
model = timm.create_model("resnet34")
2.查看可以直接创建的预训练模型列表
availableModels = timm.list_models(pretrained=True)
print(availableModels)
print(len(availableModels))
3.参数:pretrained = True
使用预训练模型,如果本地路径没有则会自动联网下载,第二次开始自动调用本地模型。
2. 使用本地的预训练模型
2-1. 国内镜像下载模型:https://hf-mirror.com/
因为huggingface.co connect error, 需要自己下载模型到本地再进行加载, 网上的方式千奇百怪,弄得很麻烦,还有搭梯子的,这里记录我觉得最快的方法。
进入huggingface国内镜像,搜索自己想要的模型,在files and versions 中下载bin文件。
https://hf-mirror.com/
2-2. 查找对应模型名称
如果不知道自己应该搜索哪个模型,提供一个方法:
直接运行
model = timm.create_model(“resnet18”, pretrained = True)
由于connectionError,会有一个报错显示
requests.exceptions.ConnectionError: (MaxRetryError(“HTTPSConnectionPool(host=‘huggingface.co’, port=443): Max retries exceeded with url: /timm/resnet18.a1_in1k/resolve/main/pytorch_model.bin (Caused by NewConnectionError(‘<urllib3.connection.HTTPSConnection object at 0x7f8e5c7977f0>: Failed to establish a new connection: [Errno 101] Network is unreachable’))”), ‘(Request ID: be47fa93-75d7-4ab8-8588-1b6cb75ac408)’)
url中显示了预训练模型的名称,即resnet18.a1_in1k,进入镜像网站搜索resnet18.a1_in1k即可
2-3.调用bin文件作为预训练模型
create_model 添加参数:pretrained_cfg_overlay, 解决
model_ft = timm.create_model('resnet18', pretrained=True, pretrained_cfg_overlay=dict(file='/home/xiaoxin/Documents/hc/-/bin/pytorch_model.bin'))