python系列&deep_study系列:模型下载的几种方式




模型下载的几种方式

问题描述

作为一名自然语言处理算法人员,hugging face开源的transformers包在日常的使用十分频繁。在使用过程中,每次使用新模型的时候都需要进行下载。如果训练用的服务器有网,那么可以通过调用from_pretrained方法直接下载模型。但是就本人的体验来看,这种方式尽管方便,但还是会有两方面的问题:

  • 如果网络很不好,模型下载时间会很久,一个小模型下载几个小时也很常见

  • 如果换了训练服务器,又要重新下载。

这里可能大家会疑惑,为什么不能把当前下载好的模型迁移过去,我们可以看下通过from_pretrained保存的文件(一般在~/.cache/huggingface/transformers文件夹下)模型文件

!https://s3-us-west-2.amazonaws.com/secure.notion-static.com/79042590-35ff-4181-9c70-1db5bf713183/v2-6a9100687e302faffa91950ac21102f1_720w.jpg

推荐方式

transformers下载 推荐

from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

model_name = "openai/whisper-large-v2"
processor = AutoProcessor.from_pretrained(model_name )
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_name )

Hugging Face Hub 下载 推荐

pip install huggingface_hub

from huggingface_hub import snapshot_download

snapshot_download(repo_id="bert-base-chinese")
# allow_regex和ignore_regex两个参数,简单来说前者是对指定的匹配项进行下载,后者是忽略指定的匹配项,下载其余部分
snapshot_download(repo_id="bert-base-chinese", ignore_regex=["*.h5", "*.ot", "*.msgpack"])

requests 下载

import os
import json
import requests
from uuid import uuid4
from tqdm import tqdm

SESSIONID = uuid4().hex

VOCAB_FILE = "vocab.txt"
CONFIG_FILE = "config.json"
MODEL_FILE = "pytorch_model.bin"
BASE_URL = "https://huggingface.co/{}/resolve/main/{}"

headers = {'user-agent': 'transformers/4.8.2; python/3.8.5;  \
			session_id/{}; torch/1.9.0; tensorflow/2.5.0; \
			file_type/model; framework/pytorch; from_auto_class/False'.format(SESSIONID)}

model_id = "bert-base-chinese"

# 创建模型对应的文件夹

model_dir = model_id.replace("/", "-")

if not os.path.exists(model_dir):
	os.mkdir(model_dir)

# vocab 和 config 文件可以直接下载

r = requests.get(BASE_URL.format(model_id, VOCAB_FILE), headers=headers)
r.encoding = "utf-8"
with open(os.path.join(model_dir, VOCAB_FILE), "w", encoding="utf-8") as f:
	f.write(r.text)
	print("{}词典文件下载完毕!".format(model_id))

r = requests.get(BASE_URL.format(model_id, CONFIG_FILE), headers=headers)
r.encoding = "utf-8"
with open(os.path.join(model_dir, CONFIG_FILE), "w", encoding="utf-8") as f:
	json.dump(r.json(), f, indent="\t")
	print("{}配置文件下载完毕!".format(model_id))

# 模型文件需要分两步进行

# Step1 获取模型下载的真实地址
r = requests.head(BASE_URL.format(model_id, MODEL_FILE), headers=headers)
r.raise_for_status()
if 300 <= r.status_code <= 399:
	url_to_download = r.headers["Location"]

# Step2 请求真实地址下载模型
r = requests.get(url_to_download, stream=True, proxies=None, headers=None)
r.raise_for_status()

# 这里的进度条是可选项,直接使用了transformers包中的代码
content_length = r.headers.get("Content-Length")
total = int(content_length) if content_length is not None else None
progress = tqdm(
	unit="B",
	unit_scale=True,
	total=total,
	initial=0,
	desc="Downloading Model",
)

with open(os.path.join(model_dir, MODEL_FILE), "wb") as temp_file:
	for chunk in r.iter_content(chunk_size=1024):
		if chunk:  # filter out keep-alive new chunks
			progress.update(len(chunk))
			temp_file.write(chunk)

progress.close()

print("{}模型文件下载完毕!".format(model_id))

Git LFS 下载

准备工作

Git LFS的方案相较于前面自行实现的方案要简洁的多得多。我们需要在安装git的基础上,再安装git lfs。以Windows为例,命令如下

git lfs install
模型下载

我们还是以bert-base-chinese为例进行下载,打开具体的模型面,可以看到右上角有一个Use in Transformersbutton

点击该Button,我们就可以看到具体的下载命令了。

拷贝命令在终端执行,就可以下载了。下载后的格式,和前面自行实现的代码是一样,但是就使用体验上来看,这种方式明显会更加优雅!

但是,这种方案也存在着一定的问题,即会下载仓库中的所有文件,会大大延长模型下载的时间。我们可以看到在目录中包含着flax_model.msgpacktf_model.h5pytorch_model.bin三个不同框架模型文件,在bert-base-uncased的版本中,还存在着rust版本rust_model.ot模型,如果我们只想要一个版本的模型文件,这种方案就无法实现了。

如果想实现模型精确下载,我们还可以借助Hugging Face Hub,下面来介绍这种方案。







Sanfor

模型下载的几种方式

  • 6
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

坦笑&&life

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值