前言
本文介绍的是在网络条件不好的情况下,直接跑.from_pretrained()下载不了权重的解决办法,本文以AutoencoderKL.from_pretrained(model_key, subfolder=“vae”)为例。
注:本文仅提供一个找到url的思路,需要一定的python功底。
一、huggingface是什么?
抱脸是一个提供预训练模型和数据集的开源平台,可能需要科学上网。
二、怎么找要下载文件的url
先找到你python安装huggingface_hub的位置,比如我的位置在:C:\Users\zkzou\AppData\Local\Programs\Python\Python38\Lib\site-packages\huggingface_hub,进入该位置并打开file_download.py文件,然后在hf_huf_download函数的这一行加入print(“downloading url is:”,url)语句就行,大概如下:
1230 url = hf_hub_url(repo_id, filename, repo_type=repo_type, revision=revision, endpoint=endpoint)
1231 print("downloading url is:",url)
1232 headers = build_hf_headers(
1233 token=token,
1234 library_name=library_name,
1235 library_version=library_version,
1236 user_agent=user_agent,
1237 )
log就能打印出url来,例如本文打印的是:
downloading url is: https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/vae/config.json
然后浏览器打开https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/vae,把这个.json文件下载下来。
那么如果打印不出来呢?
别急,只要在hf_hub_download()这个函数下面加一个
import pdb; pdb.set_trace()
然后一直按n往后走,走到你觉得是url的代码那里打印一下变量就行了。
三、下载好后放哪
这个主要看.from_pretrained函数后面跟的地址,一般情况下是丢到.cache文件夹里,比如windows系统是“C:\Users\zkzou2.cache\huggingface\hub”下,Linux一般是“/user/.cache/huggingface/hub”下。如果你是下huggingface的文件且当前运行的代码有“./huggingface/hub”文件夹,则有可能丢到当前的“./huggingface/hub”文件夹里。也可以指定路径,比如本文的代码如下:
model_key = r"stabilityai/stable-diffusion-2-base"
AutoencoderKL.from_pretrained(model_key, subfolder="vae")
只要在当前目录下创建./stabilityai/stable-diffusion-2-base,然后把下载好的文件丢进去就行了。
如果再跑一遍,log还提示少了xxx,就再从上一步那个网站下载xxx并放进去就行了。
注:本文只是拿AutoencoderKL为例,有些包比如vocos没有model_key这个参数,你可以直接去 XXX\site-packages\vocos\pretrained.py里的from_pretrain()函数中强行指定路径,如果你不想改变源码的功能性,可以和我这样改:
修改前:
@classmethod
def from_pretrained(cls, repo_id: str, revision: Optional[str] = None) -> Vocos:
config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml", revision=revision)
model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", revision=revision)
...
修改后:
def from_pretrained(cls, repo_id: str, revision: Optional[str] = None) -> Vocos:
config_path = 'vocos/config.yaml' if os.path.isfile('vocos/config.yaml') else hf_hub_download(repo_id=repo_id, filename="config.yaml", revision=revision)
model_path = 'vocos/pytorch_model.bin' if os.path.isfile('vocos/pytorch_model.bin') else hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", revision=revision)
...
其中’vocos/config.yaml’ 和’vocos/pytorch_model.bin’代表工程路径下如果存在vocos文件夹且文件夹中有配置文件和权重的话就直接读取,否则就从网上下载。之所以写成xxx = xxx if xxx else xxx是为了不破坏源代码的行数。