tf2 BERT模型文件保存和加载

探索一下BERT模型保存和加载方式,基于源码。

保存

所谓“保存模型”一般是指保存ckpt和saved_model两种格式的模型。
ckpt方式与session.run模型下保存模型格式一样(在sess.run模式下,通常使用saver = tf.train.Saver()和saver.save()保存模型),这种模型文件需要原始模型代码才能运行,一般用于训练中保存/加载权重。
saved_model格式是一种轻量化的模型,不仅包含权重值,还包含计算。它不需要原始模型构建代码就可以运行,因此,对共享和部署(使用 TFLite、TensorFlow.js、TensorFlow Serving 或 TensorFlow Hub)非常有用。

ckpt方式下一共会保存4个文件:

model.ckpt-xxxxx.data-00000-of-00001: 保存当前参数值。比如网络的权值,偏置,操作等等。
model.ckpt.index :保存当前参数名。二进制或者其他格式,不可直接查看 。
model.ckpt.meta:某个ckpt的meta数据 二进制 或者其他格式,不可直接查看,保存了TensorFlow计算图的结构信息。
checkpoint:文本文件,记录了保存的最新的checkpoint文件以及其它checkpoint文件列表。


1.默认checkpoint的保存行为
每10分钟(600 秒)写入一个checkpoint,还会在train方法开始(第一次迭代)和完成(最后一次迭代)时写入一个checkpoint;只在目录中保留5个最近写入的checkpoint; 

2.修改默认checkpoint的保存行为
创建一个RunConfig对象来定义所需的时间安排;在实例化Estimator时,将该RunConfig对象传递给Estimator的config参数;

my_checkpointing_config = tf.estimator.RunConfig(
    save_checkpoints_secs = 20*60,  # 每20分钟保存一次checkpoint
    keep_checkpoint_max = 10,       # 保存10个最近的checkpoints
)

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris',
    config=my_checkpointing_config)

加载

1.Estimator 通过运行 model_fn() 构建模型图。

2.Estimator 根据最近写入的检查点中存储的数据来初始化新模型的权重。

一旦存在检查点,TensorFlow 就会在您每次调用 train()、evaluate() 或 predict() 时加载模型。

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
对于BERT模型,您可以使用Hugging Face的`transformers`库来加载和使用。下面是一个示例代码: ```python from pyspark.sql import SparkSession from transformers import BertTokenizer, TFBertModel # 创建SparkSession spark = SparkSession.builder \ .appName("BERT Model Inference") \ .getOrCreate() # 加载BERT模型和tokenizer model_name = "bert-base-uncased" model = TFBertModel.from_pretrained(model_name) tokenizer = BertTokenizer.from_pretrained(model_name) # 加载测试数据 test_data = spark.read.format("csv").option("header", "true").load("path/to/your/test_data.csv") # 定义预处理函数 def preprocess_text(text): encoded_input = tokenizer(text, padding=True, truncation=True, max_length=128, return_tensors="tf") return encoded_input # 定义UDF以进行预处理 preprocess_udf = spark.udf.register("preprocess_text", preprocess_text) # 对测试数据进行预处理 preprocessed_data = test_data.withColumn("input", preprocess_udf(test_data["text"])) # 定义UDF以进行推断 infer_udf = spark.udf.register("infer", lambda x: model(x)["logits"].numpy().tolist()) # 进行推断 predictions = preprocessed_data.withColumn("prediction", infer_udf(preprocessed_data["input"])) # 显示预测结果 predictions.show() ``` 在上述代码中,您需要将`path/to/your/test_data.csv`替换为您的测试数据文件路径。您还可以根据需要调整模型名称和预处理选项,例如最大长度和填充方式。 请注意,此代码假设您已经安装了`transformers`库和其依赖项。如果还没有安装,可以使用以下命令进行安装: ``` pip install transformers ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值