【BERT for TensorFlow】Using BERT from local ckpt files


Summary

In this post you will learn:

  • How to convert the official ckpt checkpoint into pytorch_model.bin so it can be used from PyTorch / TensorFlow
  • How to attach additional layers on top of BERT to solve downstream tasks

The official BERT ckpt files

First, download the official BERT release files, e.g. uncased_L-12_H-768_A-12.
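
For reference, a minimal download-and-unpack sketch. The URL is the release link from the google-research/bert README (verify it is still live), and the target directory Models/ is just an assumption matching the paths used later in this post:

import urllib.request
import zipfile

# Official uncased_L-12_H-768_A-12 archive (URL taken from the google-research/bert README)
url = "https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip"
urllib.request.urlretrieve(url, "uncased_L-12_H-768_A-12.zip")
with zipfile.ZipFile("uncased_L-12_H-768_A-12.zip") as archive:
    archive.extractall("Models/")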

Then use the conversion script convert_bert_original_tf_checkpoint_to_pytorch.py below (it ships with the transformers library):

# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BERT checkpoint."""


import argparse

import torch

from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
from transformers.utils import logging


logging.set_verbosity_info()


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = BertForPreTraining(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--bert_config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained BERT model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)

ckpt to bin

Then run the following on the command line:

python convert_bert_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path bert_model.ckpt.index --bert_config_file bert_config.json  --pytorch_dump_path  pytorch_model.bin

Note: on the Windows command line, mismatched \ line-continuation characters can cause problems, so the single-line form above (with spaces in place of the \ line breaks) is safer than the multi-line form below:

python convert_bert_original_tf_checkpoint_to_pytorch.py \
--tf_checkpoint_path Models/chinese_L-12_H-768_A-12/bert_model.ckpt.index \
--bert_config_file Models/chinese_L-12_H-768_A-12/bert_config.json  \
--pytorch_dump_path  Models/chinese_L-12_H-768_A-12/pytorch_model.bin

After that you will have pytorch_model.bin; copy this file into the ckpt folder.
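
A quick sanity check on the converted file, assuming the paths from the multi-line command above (the key names in the comment are what you would typically expect, not output from the original post):

import torch

state_dict = torch.load("Models/chinese_L-12_H-768_A-12/pytorch_model.bin", map_location="cpu")
print(len(state_dict))               # number of weight tensors in the checkpoint
print(list(state_dict.keys())[:3])   # e.g. keys starting with "bert.embeddings...."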

TensorFlow Fine-Tune (or whatever :)

Finally, the converted model can be loaded and used from TensorFlow.

Loading:

from transformers import BertConfig, TFBertModel
import os

pretrained_path = "../input/uncased_L-12_H-768_A-12/"
config_path = os.path.join(pretrained_path, "bert_config.json")
checkpoint_path = os.path.join(pretrained_path, "bert_model.ckpt")
vocab_path = os.path.join(pretrained_path, "vocab.txt")


# Load the config (the file is named bert_config.json, so it is passed explicitly)
config = BertConfig.from_json_file(config_path)
# Load the plain pretrained model from the converted pytorch_model.bin (hence from_pt=True)
tfbert_model1 = TFBertModel.from_pretrained(pretrained_path, from_pt=True, config=config)

# # Or load a sequence-classification variant instead (requires importing TFBertForSequenceClassification)
# tfbert_model2 = TFBertForSequenceClassification.from_pretrained(pretrained_path, from_pt=True, config=config)
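
A minimal sanity check, assuming vocab.txt is at the vocab_path defined above (the tokenizer construction and the example sentence are mine, not from the original post):

from transformers import BertTokenizer

tokenizer = BertTokenizer(vocab_file=vocab_path)
inputs = tokenizer("hello bert", return_tensors="tf", padding="max_length", max_length=32)
outputs = tfbert_model1(inputs)
# Depending on the transformers version, outputs is a tuple or a ModelOutput;
# integer indexing works for both.
print(outputs[0].shape)  # (1, 32, 768): per-token hidden states
print(outputs[1].shape)  # (1, 768): pooled [CLS] representation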

Custom model
This post uses BERT + Bi-LSTM for a three-class text classification task.

import tensorflow as tf

def build_model(max_length):
    # Encoded token ids from the BERT tokenizer.
    input_ids = tf.keras.layers.Input(
        shape=(max_length,), dtype=tf.int32, name="input_ids"
    )
    # Attention masks indicate to the model which tokens should be attended to.
    attention_masks = tf.keras.layers.Input(
        shape=(max_length,), dtype=tf.int32, name="attention_masks"
    )
    # Token type ids are binary masks identifying the different sequences in the input.
    token_type_ids = tf.keras.layers.Input(
        shape=(max_length,), dtype=tf.int32, name="token_type_ids"
    )
    # Loading the pretrained BERT model.
    bert_model = TFBertModel.from_pretrained(pretrained_path, from_pt=True, config=config)

    # Freeze the BERT model to reuse the pretrained features without modifying them.
    bert_model.trainable = False
    '''From here on, customize!!!'''
    # Note: on newer transformers versions this tuple unpacking may require
    # calling with return_dict=False (or indexing the returned ModelOutput).
    sequence_output, pooled_output = bert_model(
        input_ids, attention_mask=attention_masks, token_type_ids=token_type_ids
    )
    # Add trainable layers on top of the frozen layers to adapt the pretrained features to the new data.
    bi_lstm = tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(64, return_sequences=True)
    )(sequence_output)
    # Apply a hybrid pooling approach to the bi_lstm sequence output.
    avg_pool = tf.keras.layers.GlobalAveragePooling1D()(bi_lstm)
    max_pool = tf.keras.layers.GlobalMaxPooling1D()(bi_lstm)
    concat = tf.keras.layers.concatenate([avg_pool, max_pool])
    dropout = tf.keras.layers.Dropout(0.3)(concat)
    output = tf.keras.layers.Dense(3, activation="softmax")(dropout)
    model = tf.keras.models.Model(
        inputs=[input_ids, attention_masks, token_type_ids], outputs=output
    )

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss="categorical_crossentropy",
        metrics=["acc"],
    )
    return model

Following the code above, everything after the '''From here on, customize!!!''' marker is yours to change: attach whatever downstream-task layers you like, and finally inspect the assembled model with model.summary(), as in the sketch below.
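
For example, a hypothetical end-to-end call (build_model is the wrapper shown above; the train_* arrays are placeholders for your own encoded data, not from the original post):

model = build_model(max_length=128)
model.summary()
# model.fit(
#     [train_input_ids, train_attention_masks, train_token_type_ids],
#     train_labels,            # one-hot labels for the 3-class softmax
#     epochs=3, batch_size=32,
# )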
