问题
当我们想使用官方预训练好的CLIP模型时需要用到以下方法来加载模型:
model, preprocess = clip.load("ViT-B/32", device=device)
但是,因为网络问题往往导致在下载模型参数时发生错误。
解决思路
因为通过代码下载时会报错,所以考虑能否预先下载好权重文件,再通过代码加载。
找到clip.py
这个文件,其中的load
函数部分定义如下:
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
"""Load a CLIP model
Parameters
----------
name : str
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
device : Union[str, torch.device]
The device to put the loaded model
jit : bool
Whether to load the optimized JIT model or more hackable non-JIT model (default).
download_root: str
path to download the model files; by default, it uses "~/.cache/clip"
Returns
-------
model : torch.nn.Module
The CLIP model
preprocess : Callable[[PIL.Image], torch.Tensor]
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
"""
if name in _MODELS:
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
elif os.path.isfile(name):
model_path = name
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
with open(model_path, 'rb') as opened_file:
try:
# loading JIT archive
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
state_dict = None
except RuntimeError:
# loading saved state dict
if jit:
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
jit = False
state_dict = torch.load(opened_file, map_location="cpu")
从中我们可以看到A model name listed by clip.available_models(), or the path to a model checkpoint containing the state_dict
,可以将clip.load
中的name
参数改为模型预训练文件的路径,且不要是_MODELS
这个字典中的关键字就不会去下载了。那么之后就是从哪里下再权重文件的问题了。我们看_MODELS
_MODELS = {
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
"RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}
从中就可以找到各个模型所对应的权重文件的下载地址。之后通过
model, preprocess = clip.load(checkpoint_path, device=device)
即可离线加载模型。