【从官方案例学框架Tensorflow/Keras】基于BERT解决SQuAD文本抽取任务
注:本系列仅帮助大家快速理解、学习并能独立使用相关框架进行深度学习的研究,理论部分还请自行学习补充,每个框架的官方经典案例写的都非常好,很值得进行学习使用。可以说在完全理解官方经典案例后加以修改便可以解决大多数常见的相关任务。
摘要:基于BERT解决SQuAD文本抽取QA任务,注意理解解决这个QA任务的方法和代码的实现
目录
1 Introduction
本例中将使用SQuAD(Stanford Question-Answering Dataset),斯坦福QA数据集。在SQuAD中,输入是由一个问题、一段文本组成。目标是在这段文本中找出问题的答案(答案是原文本的一段连续tokens)。我们将用"Exact Match"完全匹配的评估方式来进行评测,这衡量了准确匹配任何一个真实答案的预测的百分比。
使用BERT微调解决这项任务的方法如下:
- 将文本内容、问题当作BERT的输入
- 创建两个可训练的向量S、T,其维度与BERT输出的隐藏状态维度相同;
- 计算每个token作为答案片段的开始和结尾的概率。开始的概率是通过token与S的点积并Softmax得到,结尾的概率是通过token与T的点积并Softmax得到
- 在训练S、T向量过程中也微调BERT模型
2 Setup
import os
import re
import json
import string
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tokenizers import BertWordPieceTokenizer
from transformers import BertTokenizer, TFBertModel, BertConfig
# Length every encoded example is padded to; also the model's input length.
max_len = 384
# NOTE(review): `configuration` is not referenced again in the visible code;
# create_model() reads a name `config` instead — confirm which one is intended.
configuration = BertConfig() # default parameters and configuration for BERT
3 Set-up BERT tokenizer
# Save the slow pretrained tokenizer so its files are available under save_path
slow_tokenizer = BertTokenizer.from_pretrained("./input/uncased_L-12_H-768_A-12/")
save_path = "Text Extraction with BERT/"
# exist_ok=True replaces the race-prone "check, then create" sequence
os.makedirs(save_path, exist_ok=True)
slow_tokenizer.save_pretrained(save_path)
# Load the fast WordPiece tokenizer from the local pretrained vocabulary file
tokenizer = BertWordPieceTokenizer("./input/uncased_L-12_H-768_A-12/vocab.txt", lowercase=True)
4 Load the data
# SQuAD v1.1 training split; get_file returns the local path that is opened later.
# (presumably the download is cached between runs — confirm against Keras docs)
train_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json"
train_path = keras.utils.get_file("train.json", train_data_url)
# SQuAD v1.1 dev split, used here as the evaluation set.
eval_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json"
eval_path = keras.utils.get_file("eval.json", eval_data_url)
5 Preprocess the data
- 处理JSON文件,并将每一条记录封装为`SquadExample`对象;如下代码虽长,却均为数据处理,不做阐述
- 处理`SquadExample`对象,提取 x_train, y_train, x_eval, y_eval
class SquadExample:
    """One SQuAD question/answer pair together with its model-ready encoding.

    After `preprocess()` has run, either `skip` is True (the example could not
    be encoded) or the instance carries `input_ids`, `token_type_ids`,
    `attention_mask`, `start_token_idx`, `end_token_idx` and
    `context_token_to_char`.
    """

    def __init__(self, question, context, start_char_idx, answer_text, all_answers):
        self.question = question
        self.context = context
        self.start_char_idx = start_char_idx
        self.answer_text = answer_text
        self.all_answers = all_answers
        self.skip = False

    def preprocess(self):
        """Tokenize the example and locate the answer span in token space.

        Reads the module-level `tokenizer` and `max_len`. Sets `self.skip = True`
        and returns early when the answer runs past the end of the context,
        when no context token overlaps the answer, or when the encoded
        context + question is longer than `max_len`.
        """
        start_char_idx = self.start_char_idx

        # Clean context, answer and question (collapse runs of whitespace)
        context = " ".join(str(self.context).split())
        question = " ".join(str(self.question).split())
        answer = " ".join(str(self.answer_text).split())

        # Find end character index of answer in context. The end index is
        # exclusive, so an answer finishing on the last character gives
        # end_char_idx == len(context), which is still a valid span.
        end_char_idx = start_char_idx + len(answer)
        if end_char_idx > len(context):
            self.skip = True
            return

        # Mark the character indexes in context that are in answer
        is_char_in_ans = [0] * len(context)
        for idx in range(start_char_idx, end_char_idx):
            is_char_in_ans[idx] = 1

        # Tokenize context
        tokenized_context = tokenizer.encode(context)

        # Find tokens that were created from answer characters
        ans_token_idx = [
            idx
            for idx, (start, end) in enumerate(tokenized_context.offsets)
            if any(is_char_in_ans[start:end])
        ]
        if not ans_token_idx:
            self.skip = True
            return

        # Find start and end token index for tokens from answer
        start_token_idx = ans_token_idx[0]
        end_token_idx = ans_token_idx[-1]

        # Tokenize question
        tokenized_question = tokenizer.encode(question)

        # Create inputs: the question's first token is dropped before it is
        # appended to the context; segment 0 marks context, segment 1 question.
        question_ids = tokenized_question.ids[1:]
        input_ids = tokenized_context.ids + question_ids
        token_type_ids = [0] * len(tokenized_context.ids) + [1] * len(question_ids)
        attention_mask = [1] * len(input_ids)

        # Pad and create attention masks.
        # Skip if truncation is needed
        padding_length = max_len - len(input_ids)
        if padding_length > 0:  # pad
            input_ids = input_ids + ([0] * padding_length)
            attention_mask = attention_mask + ([0] * padding_length)
            token_type_ids = token_type_ids + ([0] * padding_length)
        elif padding_length < 0:  # skip
            self.skip = True
            return

        self.input_ids = input_ids
        self.token_type_ids = token_type_ids
        self.attention_mask = attention_mask
        self.start_token_idx = start_token_idx
        self.end_token_idx = end_token_idx
        # Character offsets of each context token, used to map predictions back to text
        self.context_token_to_char = tokenized_context.offsets
# Read the JSON files as UTF-8 explicitly instead of relying on the
# platform's default text encoding.
with open(train_path, encoding="utf-8") as f:
    raw_train_data = json.load(f)
with open(eval_path, encoding="utf-8") as f:
    raw_eval_data = json.load(f)
def create_squad_examples(raw_data):
    """Build one preprocessed `SquadExample` per question in `raw_data`.

    `raw_data` is the parsed SQuAD JSON, walked as
    data -> paragraphs -> qas -> answers. The first answer supplies the answer
    text and start offset; the texts of all answers are kept alongside it.
    Every example is returned, including those that preprocessing flagged
    with `skip`.
    """
    examples = []
    for article in raw_data["data"]:
        for paragraph in article["paragraphs"]:
            passage = paragraph["context"]
            for qa in paragraph["qas"]:
                question_text = qa["question"]
                first_answer_text = qa["answers"][0]["text"]
                every_answer_text = [answer["text"] for answer in qa["answers"]]
                answer_start = qa["answers"][0]["answer_start"]
                example = SquadExample(
                    question_text,
                    passage,
                    answer_start,
                    first_answer_text,
                    every_answer_text,
                )
                example.preprocess()
                examples.append(example)
    return examples
def create_inputs_targets(squad_examples):
    """Stack the encoded fields of all non-skipped examples into NumPy arrays.

    Args:
        squad_examples: iterable of objects exposing `skip` and, when `skip`
            is false, the attributes `input_ids`, `token_type_ids`,
            `attention_mask`, `start_token_idx` and `end_token_idx`.

    Returns:
        A tuple `(x, y)` where `x` is the list
        `[input_ids, token_type_ids, attention_mask]` and `y` is the list
        `[start_token_idx, end_token_idx]`, each entry a NumPy array with one
        row per kept example.
    """
    keys = (
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "start_token_idx",
        "end_token_idx",
    )
    kept = [item for item in squad_examples if not item.skip]
    arrays = {key: np.array([getattr(item, key) for item in kept]) for key in keys}
    x = [arrays["input_ids"], arrays["token_type_ids"], arrays["attention_mask"]]
    y = [arrays["start_token_idx"], arrays["end_token_idx"]]
    return x, y
# Build and preprocess every training example, then stack the usable ones.
train_squad_examples = create_squad_examples(raw_train_data)
x_train, y_train = create_inputs_targets(train_squad_examples)
# NOTE: this count covers all examples, including those flagged skip=True
# that create_inputs_targets leaves out of x_train / y_train.
print(f"{len(train_squad_examples)} training points created.")
# Same two steps for the evaluation split.
eval_squad_examples = create_squad_examples(raw_eval_data)
x_eval, y_eval = create_inputs_targets(eval_squad_examples)
# NOTE: as above, the printed count includes skipped examples.
print(f"{len(eval_squad_examples)} evaluation points created.")
6 Create the Question-Answering Model using BERT and Functional API
使用BERT完成QA模型。有关如何下载BERT官方ckpt文件并导入本地使用,请参照《本地ckpt文件的BERT使用》。
下面代码中值得注意的是`start_logits`和`end_logits`,它们分别对应最开始所说的向量S和T;注意两者相互独立,因此模型有两个输出,并各自使用一个相同形式的损失函数。
from transformers import BertConfig,TFBertModel
import os
# Local directory holding the pretrained BERT files
pretrained_path = "./input/uncased_L-12_H-768_A-12/"
config_path = os.path.join(pretrained_path, "bert_config.json")
# create_model() passes `config` to TFBertModel.from_pretrained, so the
# configuration has to be loaded before the model is built.
config = BertConfig.from_json_file(config_path)
checkpoint_path = os.path.join(pretrained_path, "bert_model.ckpt")
vocab_path = os.path.join(pretrained_path, 'vocab.txt')
def create_model():
    """Build and compile the BERT span-extraction model.

    Reads the module-level `pretrained_path`, `config` and `max_len`.

    Returns:
        A compiled `keras.Model` that takes
        `[input_ids, token_type_ids, attention_mask]` (each of length
        `max_len`) and outputs `[start_probs, end_probs]`, the softmax
        distributions over token positions for the answer start and end.
    """
    ## BERT encoder
    encoder = TFBertModel.from_pretrained(pretrained_path, from_pt=True, config=config)

    ## QA Model
    input_ids = layers.Input(shape=(max_len,), dtype=tf.int32)
    token_type_ids = layers.Input(shape=(max_len,), dtype=tf.int32)
    attention_mask = layers.Input(shape=(max_len,), dtype=tf.int32)
    # First encoder output: the per-token hidden states
    embedding = encoder(
        input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
    )[0]

    # One bias-free Dense(1) per head scores every token; Flatten drops the
    # trailing unit axis so each head yields one logit per position.
    start_logits = layers.Dense(1, name="start_logit", use_bias=False)(embedding)
    start_logits = layers.Flatten()(start_logits)

    end_logits = layers.Dense(1, name="end_logit", use_bias=False)(embedding)
    end_logits = layers.Flatten()(end_logits)

    start_probs = layers.Activation(keras.activations.softmax)(start_logits)
    end_probs = layers.Activation(keras.activations.softmax)(end_logits)

    model = keras.Model(
        inputs=[input_ids, token_type_ids, attention_mask],
        outputs=[start_probs, end_probs],
    )
    # The outputs are already probabilities, hence from_logits=False
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    # `learning_rate` is the supported argument name; `lr` is a deprecated alias
    optimizer = keras.optimizers.Adam(learning_rate=5e-5)
    model.compile(optimizer=optimizer, loss=[loss, loss])
    return model
谷歌云TPU的适配,当然本地将use_tpu = False即可不使用TPU
# Set to True to build the model under a TPU distribution strategy.
use_tpu = False
if not use_tpu:
    # Local run: build the model without a distribution strategy scope
    model = create_model()
else:
    # Create distribution strategy
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
    # Create model inside the strategy scope
    with strategy.scope():
        model = create_model()
model.summary()
7 Create evaluation Callback
构造回调函数;回调函数会在每个epoch结束后通过验证数据计算出当前模型效果
def normalize_text(text):
    """Return `text` lower-cased, without ASCII punctuation or the articles
    "a", "an" and "the", and with whitespace collapsed to single spaces."""
    lowered = text.lower()
    # Delete every character listed in string.punctuation
    without_punct = lowered.translate(str.maketrans("", "", string.punctuation))
    # Replace standalone articles with a space
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punct, flags=re.UNICODE)
    # Collapse any run of whitespace and trim the ends
    return " ".join(without_articles.split())
class ExactMatch(keras.callbacks.Callback):
    """
    Callback that prints the exact-match score on the evaluation set after
    every epoch.

    Each `SquadExample` object contains the character level offsets for each token
    in its input paragraph. We use them to get back the span of text corresponding
    to the tokens between our predicted start and end tokens.
    All the ground-truth answers are also present in each `SquadExample` object.
    We calculate the percentage of data points where the span of text obtained
    from model predictions matches one of the ground-truth answers.
    """

    def __init__(self, x_eval, y_eval):
        # Initialise the Keras Callback base class before adding our own state
        super().__init__()
        self.x_eval = x_eval
        self.y_eval = y_eval

    def on_epoch_end(self, epoch, logs=None):
        """Predict on `x_eval` and print the fraction of exact matches."""
        pred_start, pred_end = self.model.predict(self.x_eval)
        count = 0
        # Predictions are matched by position with the non-skipped examples of
        # the module-level `eval_squad_examples`.
        eval_examples_no_skip = [eg for eg in eval_squad_examples if not eg.skip]
        for idx, (start, end) in enumerate(zip(pred_start, pred_end)):
            squad_eg = eval_examples_no_skip[idx]
            offsets = squad_eg.context_token_to_char
            start = np.argmax(start)
            end = np.argmax(end)
            # A start position beyond the context tokens cannot be mapped to text
            if start >= len(offsets):
                continue
            pred_char_start = offsets[start][0]
            if end < len(offsets):
                pred_char_end = offsets[end][1]
                pred_ans = squad_eg.context[pred_char_start:pred_char_end]
            else:
                # End falls outside the context tokens: take the rest of the context
                pred_ans = squad_eg.context[pred_char_start:]

            normalized_pred_ans = normalize_text(pred_ans)
            normalized_true_ans = [normalize_text(ans) for ans in squad_eg.all_answers]
            if normalized_pred_ans in normalized_true_ans:
                count += 1
        total = len(self.y_eval[0])
        # Guard against an empty evaluation set
        acc = count / total if total else 0.0
        print(f"\nepoch={epoch+1}, exact match score={acc:.2f}")
8 Train and Evaluate
exact_match_callback = ExactMatch(x_eval, y_eval)
# x_train / y_train are lists of per-field arrays, so slicing the list itself
# would keep every sample; take the first 1000 rows of each array instead.
model.fit(
    [arr[:1000] for arr in x_train],
    [arr[:1000] for arr in y_train],
    epochs=1,  # For demonstration, 3 epochs are recommended
    verbose=2,
    batch_size=1,
    callbacks=[exact_match_callback],
)
9 Summary
基于BERT解决SQuAD文本抽取任务的完整代码如下
在本段代码中,请重点理解掌握如何解决这种SQuAD文本抽取式QA任务:其想法直接简单,实现上也并不复杂。可以说理解并完整看过一次,就能掌握解决这类问题的baseline方法。
import os
import re
import json
import string
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tokenizers import BertWordPieceTokenizer
from transformers import BertTokenizer, TFBertModel, BertConfig
'''预训练模型本地地址导入'''
# Length every encoded example is padded to; also the model's input length.
max_len = 384
# Local directory holding the pretrained BERT files
pretrained_path = "./input/uncased_L-12_H-768_A-12/"
config_path = os.path.join(pretrained_path,"bert_config.json")
# Load the BERT configuration from the local JSON file; create_model() passes
# it to TFBertModel.from_pretrained.
config = BertConfig.from_json_file(config_path)
# NOTE(review): checkpoint_path is not referenced elsewhere in the visible code.
checkpoint_path = os.path.join(pretrained_path,"bert_model.ckpt")
vocab_path = os.path.join(pretrained_path,'vocab.txt')
'''tokenizer的本地保存'''
# Save the slow pretrained tokenizer so its files are available under save_path
slow_tokenizer = BertTokenizer.from_pretrained(pretrained_path)
save_path = "Text Extraction with BERT/"
# exist_ok=True replaces the race-prone "check, then create" sequence
os.makedirs(save_path, exist_ok=True)
slow_tokenizer.save_pretrained(save_path)
# Load the fast WordPiece tokenizer from the local pretrained vocabulary file
tokenizer = BertWordPieceTokenizer(vocab_path, lowercase=True)
'''SQuAD数据集的读取与预处理'''
# SQuAD v1.1 training split; get_file returns the local path that is opened later.
# (presumably the download is cached between runs — confirm against Keras docs)
train_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json"
train_path = keras.utils.get_file("train.json", train_data_url)
# SQuAD v1.1 dev split, used here as the evaluation set.
eval_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json"
eval_path = keras.utils.get_file("eval.json", eval_data_url)
class SquadExample:
    """One SQuAD question/answer pair together with its model-ready encoding.

    After `preprocess()` has run, either `skip` is True (the example could not
    be encoded) or the instance carries `input_ids`, `token_type_ids`,
    `attention_mask`, `start_token_idx`, `end_token_idx` and
    `context_token_to_char`.
    """

    def __init__(self, question, context, start_char_idx, answer_text, all_answers):
        self.question = question
        self.context = context
        self.start_char_idx = start_char_idx
        self.answer_text = answer_text
        self.all_answers = all_answers
        self.skip = False

    def preprocess(self):
        """Tokenize the example and locate the answer span in token space.

        Reads the module-level `tokenizer` and `max_len`. Sets `self.skip = True`
        and returns early when the answer runs past the end of the context,
        when no context token overlaps the answer, or when the encoded
        context + question is longer than `max_len`.
        """
        start_char_idx = self.start_char_idx

        # Clean context, answer and question (collapse runs of whitespace)
        context = " ".join(str(self.context).split())
        question = " ".join(str(self.question).split())
        answer = " ".join(str(self.answer_text).split())

        # Find end character index of answer in context. The end index is
        # exclusive, so an answer finishing on the last character gives
        # end_char_idx == len(context), which is still a valid span.
        end_char_idx = start_char_idx + len(answer)
        if end_char_idx > len(context):
            self.skip = True
            return

        # Mark the character indexes in context that are in answer
        is_char_in_ans = [0] * len(context)
        for idx in range(start_char_idx, end_char_idx):
            is_char_in_ans[idx] = 1

        # Tokenize context
        tokenized_context = tokenizer.encode(context)

        # Find tokens that were created from answer characters
        ans_token_idx = [
            idx
            for idx, (start, end) in enumerate(tokenized_context.offsets)
            if any(is_char_in_ans[start:end])
        ]
        if not ans_token_idx:
            self.skip = True
            return

        # Find start and end token index for tokens from answer
        start_token_idx = ans_token_idx[0]
        end_token_idx = ans_token_idx[-1]

        # Tokenize question
        tokenized_question = tokenizer.encode(question)

        # Create inputs: the question's first token is dropped before it is
        # appended to the context; segment 0 marks context, segment 1 question.
        question_ids = tokenized_question.ids[1:]
        input_ids = tokenized_context.ids + question_ids
        token_type_ids = [0] * len(tokenized_context.ids) + [1] * len(question_ids)
        attention_mask = [1] * len(input_ids)

        # Pad and create attention masks.
        # Skip if truncation is needed
        padding_length = max_len - len(input_ids)
        if padding_length > 0:  # pad
            input_ids = input_ids + ([0] * padding_length)
            attention_mask = attention_mask + ([0] * padding_length)
            token_type_ids = token_type_ids + ([0] * padding_length)
        elif padding_length < 0:  # skip
            self.skip = True
            return

        self.input_ids = input_ids
        self.token_type_ids = token_type_ids
        self.attention_mask = attention_mask
        self.start_token_idx = start_token_idx
        self.end_token_idx = end_token_idx
        # Character offsets of each context token, used to map predictions back to text
        self.context_token_to_char = tokenized_context.offsets
def create_squad_examples(raw_data):
    """Build one preprocessed `SquadExample` per question in `raw_data`.

    `raw_data` is the parsed SQuAD JSON, walked as
    data -> paragraphs -> qas -> answers. The first answer supplies the answer
    text and start offset; the texts of all answers are kept alongside it.
    Every example is returned, including those that preprocessing flagged
    with `skip`.
    """
    examples = []
    for article in raw_data["data"]:
        for paragraph in article["paragraphs"]:
            passage = paragraph["context"]
            for qa in paragraph["qas"]:
                question_text = qa["question"]
                first_answer_text = qa["answers"][0]["text"]
                every_answer_text = [answer["text"] for answer in qa["answers"]]
                answer_start = qa["answers"][0]["answer_start"]
                example = SquadExample(
                    question_text,
                    passage,
                    answer_start,
                    first_answer_text,
                    every_answer_text,
                )
                example.preprocess()
                examples.append(example)
    return examples
def create_inputs_targets(squad_examples):
    """Stack the encoded fields of all non-skipped examples into NumPy arrays.

    Args:
        squad_examples: iterable of objects exposing `skip` and, when `skip`
            is false, the attributes `input_ids`, `token_type_ids`,
            `attention_mask`, `start_token_idx` and `end_token_idx`.

    Returns:
        A tuple `(x, y)` where `x` is the list
        `[input_ids, token_type_ids, attention_mask]` and `y` is the list
        `[start_token_idx, end_token_idx]`, each entry a NumPy array with one
        row per kept example.
    """
    keys = (
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "start_token_idx",
        "end_token_idx",
    )
    kept = [item for item in squad_examples if not item.skip]
    arrays = {key: np.array([getattr(item, key) for item in kept]) for key in keys}
    x = [arrays["input_ids"], arrays["token_type_ids"], arrays["attention_mask"]]
    y = [arrays["start_token_idx"], arrays["end_token_idx"]]
    return x, y
# Read the JSON files as UTF-8 explicitly instead of relying on the
# platform's default text encoding.
with open(train_path, encoding="utf-8") as f:
    raw_train_data = json.load(f)
with open(eval_path, encoding="utf-8") as f:
    raw_eval_data = json.load(f)
# Build and preprocess every training example, then stack the usable ones.
train_squad_examples = create_squad_examples(raw_train_data)
x_train, y_train = create_inputs_targets(train_squad_examples)
# NOTE: this count covers all examples, including those flagged skip=True
# that create_inputs_targets leaves out of x_train / y_train.
print(f"{len(train_squad_examples)} training points created.")
# Same two steps for the evaluation split.
eval_squad_examples = create_squad_examples(raw_eval_data)
x_eval, y_eval = create_inputs_targets(eval_squad_examples)
# NOTE: as above, the printed count includes skipped examples.
print(f"{len(eval_squad_examples)} evaluation points created.")
'''定义回调函数'''
def normalize_text(text):
    """Return `text` lower-cased, without ASCII punctuation or the articles
    "a", "an" and "the", and with whitespace collapsed to single spaces."""
    lowered = text.lower()
    # Delete every character listed in string.punctuation
    without_punct = lowered.translate(str.maketrans("", "", string.punctuation))
    # Replace standalone articles with a space
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punct, flags=re.UNICODE)
    # Collapse any run of whitespace and trim the ends
    return " ".join(without_articles.split())
class ExactMatch(keras.callbacks.Callback):
    """
    Callback that prints the exact-match score on the evaluation set after
    every epoch.

    Each `SquadExample` object contains the character level offsets for each token
    in its input paragraph. We use them to get back the span of text corresponding
    to the tokens between our predicted start and end tokens.
    All the ground-truth answers are also present in each `SquadExample` object.
    We calculate the percentage of data points where the span of text obtained
    from model predictions matches one of the ground-truth answers.
    """

    def __init__(self, x_eval, y_eval):
        # Initialise the Keras Callback base class before adding our own state
        super().__init__()
        self.x_eval = x_eval
        self.y_eval = y_eval

    def on_epoch_end(self, epoch, logs=None):
        """Predict on `x_eval` and print the fraction of exact matches."""
        pred_start, pred_end = self.model.predict(self.x_eval)
        count = 0
        # Predictions are matched by position with the non-skipped examples of
        # the module-level `eval_squad_examples`.
        eval_examples_no_skip = [eg for eg in eval_squad_examples if not eg.skip]
        for idx, (start, end) in enumerate(zip(pred_start, pred_end)):
            squad_eg = eval_examples_no_skip[idx]
            offsets = squad_eg.context_token_to_char
            start = np.argmax(start)
            end = np.argmax(end)
            # A start position beyond the context tokens cannot be mapped to text
            if start >= len(offsets):
                continue
            pred_char_start = offsets[start][0]
            if end < len(offsets):
                pred_char_end = offsets[end][1]
                pred_ans = squad_eg.context[pred_char_start:pred_char_end]
            else:
                # End falls outside the context tokens: take the rest of the context
                pred_ans = squad_eg.context[pred_char_start:]

            normalized_pred_ans = normalize_text(pred_ans)
            normalized_true_ans = [normalize_text(ans) for ans in squad_eg.all_answers]
            if normalized_pred_ans in normalized_true_ans:
                count += 1
        total = len(self.y_eval[0])
        # Guard against an empty evaluation set
        acc = count / total if total else 0.0
        print(f"\nepoch={epoch+1}, exact match score={acc:.2f}")
'''定义模型'''
def create_model():
    """Build and compile the BERT span-extraction model.

    Reads the module-level `pretrained_path`, `config` and `max_len`.

    Returns:
        A compiled `keras.Model` that takes
        `[input_ids, token_type_ids, attention_mask]` (each of length
        `max_len`) and outputs `[start_probs, end_probs]`, the softmax
        distributions over token positions for the answer start and end.
    """
    ## BERT encoder
    encoder = TFBertModel.from_pretrained(pretrained_path, from_pt=True, config=config)

    ## QA Model
    input_ids = layers.Input(shape=(max_len,), dtype=tf.int32)
    token_type_ids = layers.Input(shape=(max_len,), dtype=tf.int32)
    attention_mask = layers.Input(shape=(max_len,), dtype=tf.int32)
    # First encoder output: the per-token hidden states
    embedding = encoder(
        input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
    )[0]

    # One bias-free Dense(1) per head scores every token; Flatten drops the
    # trailing unit axis so each head yields one logit per position.
    start_logits = layers.Dense(1, name="start_logit", use_bias=False)(embedding)
    start_logits = layers.Flatten()(start_logits)

    end_logits = layers.Dense(1, name="end_logit", use_bias=False)(embedding)
    end_logits = layers.Flatten()(end_logits)

    start_probs = layers.Activation(keras.activations.softmax)(start_logits)
    end_probs = layers.Activation(keras.activations.softmax)(end_logits)

    model = keras.Model(
        inputs=[input_ids, token_type_ids, attention_mask],
        outputs=[start_probs, end_probs],
    )
    # The outputs are already probabilities, hence from_logits=False
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    # `learning_rate` is the supported argument name; `lr` is a deprecated alias
    optimizer = keras.optimizers.Adam(learning_rate=5e-5)
    model.compile(optimizer=optimizer, loss=[loss, loss])
    return model
# Build and compile the QA model (no distribution strategy scope is used here).
model = create_model()
'''模型训练'''
exact_match_callback = ExactMatch(x_eval, y_eval)
# x_train / y_train are lists of per-field arrays, so slicing the list itself
# would keep every sample; take the first 1000 rows of each array instead.
model.fit(
    [arr[:1000] for arr in x_train],
    [arr[:1000] for arr in y_train],
    epochs=1,  # For demonstration, 3 epochs are recommended
    verbose=2,
    batch_size=1,
    callbacks=[exact_match_callback],
)