How to Use BERT Models from TensorFlow Hub

This post shows how to load pretrained BERT models from TensorFlow Hub in TensorFlow, including models trained on the MNLI, SQuAD, and PubMed tasks. It covers tokenizing raw text and converting it to IDs, using a loaded model to generate pooled and sequence outputs, and comparing the semantic similarity of sentence embeddings, along with a recommendation to run on GPU and links to more BERT models and fine-tuning tutorials.

Colab Notebook: https://colab.research.google.com/github/tensorflow/hub/blob/master/examples/colab/bert_experts.ipynb

GitHub Code: https://github.com/tensorflow/hub/blob/master/examples/colab/bert_experts.ipynb

Download Notebook: https://storage.googleapis.com/tensorflow_docs/hub/examples/colab/bert_experts.ipynb

View TFHub models: https://hub.tensorflow.google.cn/s?q=experts%2Fbert

Last week we introduced the rich variety of BERT and BERT-like models available on TensorFlow Hub. Today we'll walk through a Colab that demonstrates how to:
  • Load BERT models from TensorFlow Hub (hub.tensorflow.google.cn) that have been trained on different tasks, including MNLI, SQuAD, and PubMed
  • Tokenize raw text and convert it to IDs with a matching preprocessing model
  • Generate pooled and sequence outputs from the token input IDs using the loaded model
  • Look at the semantic similarity of the pooled outputs of different sentences

Note: this Colab should be run with a GPU runtime.

Setup and imports

pip3 install --quiet tensorflow
pip3 install --quiet tensorflow_text
import seaborn as sns
from sklearn.metrics import pairwise

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text # Imports TF ops for preprocessing.
Configure the model
BERT_MODEL = "https://hub.tensorflow.google.cn/google/experts/bert/wiki_books/2" # @param {type: "string"} ["https://hub.tensorflow.google.cn/google/experts/bert/wiki_books/2", "https://hub.tensorflow.google.cn/google/experts/bert/wiki_books/mnli/2", "https://hub.tensorflow.google.cn/google/experts/bert/wiki_books/qnli/2", "https://hub.tensorflow.google.cn/google/experts/bert/wiki_books/qqp/2", "https://hub.tensorflow.google.cn/google/experts/bert/wiki_books/squad2/2", "https://hub.tensorflow.google.cn/google/experts/bert/wiki_books/sst2/2",  "https://hub.tensorflow.google.cn/google/experts/bert/pubmed/2", "https://tfhub.dev/google/experts/bert/pubmed/squad2/2"]
# Preprocessing must match the model, but all the above use the same.
PREPROCESS_MODEL = "https://hub.tensorflow.google.cn/tensorflow/bert_en_uncased_preprocess/1"
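The Colab below calls hub.load directly, but the same two handles can also be wired into a Keras model with hub.KerasLayer. A minimal sketch (not part of the original notebook), reusing the PREPROCESS_MODEL and BERT_MODEL handles defined above:

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text  # Registers the TF ops the preprocessing model needs.

# Sketch: raw strings in, pooled sentence embeddings out.
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)
encoder_inputs = hub.KerasLayer(PREPROCESS_MODEL)(text_input)
encoder_outputs = hub.KerasLayer(BERT_MODEL, trainable=False)(encoder_inputs)
embedding_model = tf.keras.Model(text_input, encoder_outputs["pooled_output"])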

Sentences

Let's take some sentences from Wikipedia to run through the model:

sentences = [
    "Here We Go Then, You And I is a 1999 album by Norwegian pop artist Morten Abel. It was Abel's second CD as a solo artist.",
    "The album went straight to number one on the Norwegian album chart, and sold to double platinum.",
    "Among the singles released from the album were the songs \"Be My Lover\" and \"Hard To Stay Awake\".",
    "Riccardo Zegna is an Italian jazz musician.",
    "Rajko Maksimović is a composer, writer, and music pedagogue.",
    "One of the most significant Serbian composers of our time, Maksimović has been and remains active in creating works for different ensembles.",
    "Ceylon spinach is a common name for several plants and may refer to: Basella alba Talinum fruticosum",
    "A solar eclipse occurs when the Moon passes between Earth and the Sun, thereby totally or partly obscuring the image of the Sun for a viewer on Earth.",
    "A partial solar eclipse occurs in the polar regions of the Earth when the center of the Moon's shadow misses the Earth.",
]

Run the model

We'll load the BERT model from TF Hub, tokenize our sentences using the matching preprocessing model from TF Hub, then feed the tokenized sentences into the model. To keep this Colab fast and simple, we recommend running it on a GPU.

Go to Runtime → Change runtime type to make sure GPU is selected.

preprocess = hub.load(PREPROCESS_MODEL)
bert = hub.load(BERT_MODEL)
inputs = preprocess(sentences)
outputs = bert(inputs)
print("Sentences:")
print(sentences)

print("\nBERT inputs:")
print(inputs)

print("\nPooled embeddings:")
print(outputs["pooled_output"])

print("\nPer token embeddings:")
print(outputs["sequence_output"])
Sentences:
["Here We Go Then, You And I is a 1999 album by Norwegian pop artist Morten Abel. It was Abel's second CD as a solo artist.", 'The album went straight to number one on the Norwegian album chart, and sold to double platinum.', 'Among the singles released from the album were the songs "Be My Lover" and "Hard To Stay Awake".', 'Riccardo Zegna is an Italian jazz musician.', 'Rajko Maksimović is a composer, writer, and music pedagogue.', 'One of the most significant Serbian composers of our time, Maksimović has been and remains active in creating works for different ensembles.', 'Ceylon spinach is a common name for several plants and may refer to: Basella alba Talinum fruticosum', 'A solar eclipse occurs when the Moon passes between Earth and the Sun, thereby totally or partly obscuring the image of the Sun for a viewer on Earth.', "A partial solar eclipse occurs in the polar regions of the Earth when the center of the Moon's shadow misses the Earth."]

BERT inputs:
{'input_type_ids': <tf.Tensor: shape=(9, 128), dtype=int32, numpy=
array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=int32)>, 'input_mask': <tf.Tensor: shape=(9, 128), dtype=int32, numpy=
array([[1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       ...,
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0]], dtype=int32)>, 'input_word_ids': <tf.Tensor: shape=(9, 128), dtype=int32, numpy=
array([[  101,  2182,  2057, ...,     0,     0,     0],
       [  101,  1996,  2201, ...,     0,     0,     0],
       [  101,  2426,  1996, ...,     0,     0,     0],
       ...,
       [  101, 16447,  6714, ...,     0,     0,     0],
       [  101,  1037,  5943, ...,     0,     0,     0],
       [  101,  1037,  7704, ...,     0,     0,     0]], dtype=int32)>}

Pooled embeddings:
tf.Tensor(
[[ 0.79759794 -0.48580435 0.49781656 ... -0.34488496 0.39727688
-0.20639414]
[ 0.57120484 -0.41205186 0.70489156 ... -0.35185218 0.19032398
-0.4041889 ]
[-0.6993836 0.1586663 0.06569844 ... -0.06232387 -0.8155013
-0.07923748]
...
[-0.35727036 0.77089816 0.15756643 ... 0.441857 -0.8644817
0.04504787]
[ 0.9107702 0.41501534 0.5606339 ... -0.49263883 0.3964067
-0.05036191]
[ 0.90502924 -0.15505327 0.726722 ... -0.34734532 0.50526506
-0.19542982]], shape=(9, 768), dtype=float32)

Per token embeddings:
tf.Tensor(
[[[ 1.09197533e+00 -5.30553877e-01 5.46399117e-01 ... -3.59626472e-01
4.20411289e-01 -2.09402084e-01]
[ 1.01438284e+00 7.80790329e-01 8.53758693e-01 ... 5.52820444e-01
-1.12457883e+00 5.60277641e-01]
[ 7.88627684e-01 7.77753443e-02 9.51507747e-01 ... -1.90755337e-01
5.92060506e-01 6.19107723e-01]
...
[-3.22031736e-01 -4.25212324e-01 -1.28237933e-01 ... -3.90951157e-01
-7.90973544e-01 4.22365129e-01]
[-3.10389847e-02 2.39855915e-01 -2.19942629e-01 ... -1.14405245e-01
-1.26804781e+00 -1.61363974e-01]
[-4.20636892e-01 5.49730241e-01 -3.24446023e-01 ... -1.84789032e-01
-1.13429689e+00 -5.89773059e-02]]

[[ 6.49309337e-01 -4.38080192e-01 8.76956999e-01 ... -3.67556065e-01
1.92673296e-01 -4.28645700e-01]
[-1.12487435e+00 2.99313068e-01 1.17996347e+00 ... 4.87294406e-01
5.34003854e-01 2.28363827e-01]
[-2.70572990e-01 3.23538631e-02 1.04257035e+00 ... 5.89937270e-01
1.53678954e+00 5.84256709e-01]
...
[-1.47624981e+00 1.82391271e-01 5.58804125e-02 ... -1.67332077e+00
-6.73984885e-01 -7.24499583e-01]
[-1.51381290e+00 5.81846952e-01 1.61421359e-01 ... -1.26408398e+00
-4.02721316e-01 -9.71973777e-01]
[-4.71531510e-01 2.28173390e-01 5.27765870e-01 ... -7.54838765e-01
-9.09029484e-01 -1.69548154e-01]]

[[-8.66093040e-01 1.60018250e-01 6.57932162e-02 ... -6.24047518e-02
-1.14323711e+00 -7.94039369e-02]
[ 7.71180928e-01 7.08045244e-01 1.13499165e-01 ... 7.88309634e-01
-3.14380586e-01 -9.74871933e-01]
[-4.40023899e-01 -3.00594330e-01 3.54794949e-01 ... 7.97353014e-02
-4.73935485e-01 -1.10018420e+00]
...
[-1.02053010e+00 2.69383639e-01 -4.73101676e-01 ... -6.63193762e-01
-1.45799184e+00 -3.46655250e-01]
[-9.70034838e-01 -4.50136065e-02 -5.97798169e-01 ... -3.05265576e-01
-1.27442575e+00 -2.80517340e-01]
[-7.31442988e-01 1.76993430e-01 -4.62578893e-01 ... -1.60623401e-01
-1.63460755e+00 -3.20607185e-01]]

...

[[-3.73753369e-01 1.02253771e+00 1.58890173e-01 ... 4.74535972e-01
-1.31081581e+00 4.50783782e-02]
[-4.15891230e-01 5.00191450e-01 -4.58438754e-01 ... 4.14822072e-01
-6.20658875e-01 -7.15549171e-01]
[-1.25043917e+00 5.09365320e-01 -5.71037054e-01 ... 3.54916602e-01
2.43683696e-01 -2.05771995e+00]
...
[ 1.33936703e-01 1.18591738e+00 -2.21700743e-01 ... -8.19471061e-01
-1.67373013e+00 -3.96926820e-01]
[-3.36624265e-01 1.65562105e+00 -3.78126293e-01 ... -9.67453301e-01
-1.48010290e+00 -8.33311737e-01]
[-2.26493448e-01 1.61784422e+00 -6.70443296e-01 ... -4.90783423e-01
-1.45356917e+00 -7.17075229e-01]]

[[ 1.53202307e+00 4.41654980e-01 6.33757174e-01 ... -5.39538860e-01
4.19378459e-01 -5.04045524e-02]
[ 8.93778205e-01 8.93955052e-01 3.06287408e-02 ... 5.90391904e-02
-2.06495613e-01 -8.48110974e-01]
[-1.85600221e-02 1.04790771e+00 -1.33295977e+00 ... -1.38697088e-01
-3.78795475e-01 -4.90686238e-01]
...
[ 1.42756522e+00 1.06969848e-01 -4.06335592e-02 ... -3.17773186e-02
-4.14598197e-01 7.00368583e-01]
[ 1.12866342e+00 1.45478487e-01 -6.13721192e-01 ... 4.74921733e-01
-3.98516655e-01 4.31243867e-01]
[ 1.43932939e+00 1.80306956e-01 -4.28539753e-01 ... -2.50225902e-01
-1.00005007e+00 3.59855264e-01]]

[[ 1.49934173e+00 -1.56314075e-01 9.21745181e-01 ... -3.62421691e-01
5.56351066e-01 -1.97976440e-01]
[ 1.11105371e+00 3.66513431e-01 3.55058551e-01 ... -5.42975247e-01
1.44716531e-01 -3.16758066e-01]
[ 2.40487278e-01 3.81156325e-01 -5.91827273e-01 ... 3.74107122e-01
-5.98296165e-01 -1.01662648e+00]
...
[ 1.01586223e+00 5.02603769e-01 1.07373089e-01 ... -9.56426382e-01
-4.10394996e-01 -2.67601997e-01]
[ 1.18489289e+00 6.54797733e-01 1.01688504e-03 ... -8.61546934e-01
-8.80392492e-02 -3.06370854e-01]
[ 1.26691115e+00 4.77678716e-01 6.62857294e-03 ... -1.15858066e+00
-7.06758797e-02 -1.86787039e-01]]], shape=(9, 128, 768), dtype=float32)
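The pooled output gives one 768-dimensional vector per sentence (shape (9, 768)), while the sequence output keeps one vector per token position (shape (9, 128, 768)). If you want a sentence vector built from the token embeddings instead of pooled_output, one common alternative is masked mean pooling; a minimal sketch (not part of the original notebook), reusing the inputs and outputs dictionaries from above:

# Average the per-token embeddings, ignoring padding positions.
mask = tf.cast(inputs["input_mask"], tf.float32)                     # (9, 128)
summed = tf.reduce_sum(outputs["sequence_output"] * mask[:, :, tf.newaxis], axis=1)
mean_pooled = summed / tf.reduce_sum(mask, axis=1, keepdims=True)    # (9, 768)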

Semantic similarity

Now let's take a look at the pooled_output embeddings of our sentences and compare how similar they are across sentences.

def plot_similarity(features, labels):
  """Plot a similarity matrix of the embeddings."""
  cos_sim = pairwise.cosine_similarity(features)
  sns.set(font_scale=1.2)
  cbar_kws = dict(use_gridspec=False, location="left")
  g = sns.heatmap(
      cos_sim, xticklabels=labels, yticklabels=labels,
      vmin=0, vmax=1, cmap="Blues", cbar_kws=cbar_kws)
  g.tick_params(labelright=True, labelleft=False)
  g.set_yticklabels(labels, rotation=0)
  g.set_title("Semantic Textual Similarity")


plot_similarity(outputs["pooled_output"], sentences)
[Figure: "Semantic Textual Similarity" heatmap of the pairwise cosine similarities between the pooled sentence embeddings]
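If you prefer raw numbers to a heatmap, the same matrix can be printed directly; a small sketch (not part of the original notebook) using the pairwise module already imported above:

import numpy as np

# Pairwise cosine similarities between the 9 pooled sentence embeddings.
cos_sim = pairwise.cosine_similarity(outputs["pooled_output"])
print(np.round(cos_sim, 2))  # 9x9 matrix; rows and columns follow sentence order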

Learn more

  • To find more BERT models, visit TensorFlow Hub: hub.tensorflow.google.cn

  • This notebook demonstrated simple inference with BERT; a more advanced tutorial on fine-tuning BERT is available at tensorflow.google.cn/official_models/fine_tuning_bert

  • We used just one GPU chip to run the model; you can learn more about how to load models with tf.distribute at tensorflow.google.cn/tutorials/distribute/save_and_load (see the sketch below)
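A minimal sketch of that pattern (an assumption based on the save_and_load tutorial, not code from this notebook): loading the SavedModels inside a distribution strategy's scope so their variables are placed according to the strategy:

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text  # Registers the TF ops the preprocessing model needs.

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # Variables created while loading are distributed per the strategy.
  preprocess = hub.load(PREPROCESS_MODEL)
  bert = hub.load(BERT_MODEL)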
