SQuAD Question Answering System Based on the Pre-trained BERT Model
step-1 Run the example
Following Hugging Face's pytorch-transformers, download and run the example script run_squad.py.
Run arguments:
python run_squad.py \
  --model_type bert \
  --model_name_or_path bert-base-uncased \
  --do_train \
  --do_eval \
  --do_lower_case \
  --train_file ../../SQUAD_DIR/train-v1.1.json \
  --predict_file ../../SQUAD_DIR/dev-v1.1.json \
  --per_gpu_train_batch_size 4 \
  --learning_rate 3e-5 \
  --num_train_epochs 2.0 \
  --max_seq_length 384 \
  --doc_stride 128 \
  --output_dir ../../SQUAD_DIR/OUTPUT \
  --overwrite_output_dir \
  --gradient_accumulation_steps 3
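Two of these arguments are easier to read with numbers attached. On a single GPU, a per-GPU batch size of 4 with 3 gradient-accumulation steps gives 4 × 3 = 12 examples per optimizer step. A max_seq_length of 384 with a doc_stride of 128 means that a long context is cut into overlapping windows. The following is a minimal sketch of that windowing idea, not the code of run_squad.py itself; the helper name split_into_windows and the lengths of 600 document tokens and 10 question tokens are assumptions chosen for illustration.

```python
def split_into_windows(num_doc_tokens, num_query_tokens,
                       max_seq_length=384, doc_stride=128):
    """Return (start, length) pairs of overlapping document windows."""
    # Reserve room for [CLS], the [SEP] after the question, and the final [SEP].
    max_tokens_for_doc = max_seq_length - num_query_tokens - 3
    windows = []
    start = 0
    while start < num_doc_tokens:
        length = min(num_doc_tokens - start, max_tokens_for_doc)
        windows.append((start, length))
        if start + length == num_doc_tokens:
            break  # the last window reaches the end of the document
        start += min(length, doc_stride)
    return windows


# Examples per optimizer step on one GPU: batch size x accumulation steps.
effective_batch = 4 * 3
print(effective_batch)

# A hypothetical 600-token context with a 10-token question.
print(split_into_windows(num_doc_tokens=600, num_query_tokens=10))
```

Each window starts 128 tokens after the previous one, so an answer near a window boundary still appears with full context in a neighbouring window.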
The config.json used (shown as the loaded Python dict):
{'attention_probs_dropout_prob': 0.1,
 'finetuning_task': None,
 'hidden_act': 'gelu',
 'hidden_dropout_prob': 0.1,
 'hidden_size': 768,
 'initializer_range': 0.02,
 'intermediate_size': 3072,
 'layer_norm_eps': 1e-12,
 'max_position_embeddings': 512,
 'num_attention_heads': 12,
 'num_hidden_layers': 12,
 'num_labels': 2,
 'output_attentions': False,
 'output_hidden_states': False,
 'pruned_heads': {},
 'torchscript': False,
 'type_vocab_size': 2,
 'vocab_size': 30522}
tokenizer_config.json
{'do_lower_case': True, 'max_len': 512, 'init_inputs': []}
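A few quantities follow directly from these configuration values. The sketch below only does the arithmetic; the variable names head_dim, ffn_ratio and fits are chosen here for illustration, and max_seq_length is taken from the run arguments above.

```python
# Values copied from the config.json and the run arguments shown above.
config = {
    "hidden_size": 768,
    "num_attention_heads": 12,
    "intermediate_size": 3072,
    "max_position_embeddings": 512,
}
max_seq_length = 384

# Each attention head works on an equal slice of the hidden vector.
head_dim = config["hidden_size"] // config["num_attention_heads"]

# The feed-forward layer is this many times wider than the hidden size.
ffn_ratio = config["intermediate_size"] // config["hidden_size"]

# The input length must not exceed the number of position embeddings.
fits = max_seq_length <= config["max_position_embeddings"]

print(head_dim, ffn_ratio, fits)
```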
Results:
{"exact": 79.2715, "f1": 86.96, "total": 10570, "HasAns_exact": 79.27, "HasAns_f1": 86.96, "HasAns_total": 10570}
Evaluation with evaluate-v1.1.py:
python evaluate-v1.1.py dev-v1.1.json OUTPUT/predictions_.json
{"exact_match": 79.27152317880795, "f1": 86.96144570829648}
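To give a feel for what the two reported numbers measure, here is a minimal sketch of exact match and token-level F1 for a single prediction against a single reference answer. It follows the general approach of SQuAD v1.1 evaluation (lowercase, strip punctuation and English articles, then compare), but it is an illustration written for these notes rather than a copy of evaluate-v1.1.py. The real evaluation also takes the best score over all reference answers for a question and averages over the dataset.

```python
import re
import string
from collections import Counter


def normalize_answer(s):
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def exact_match(prediction, ground_truth):
    """True if the normalized strings are identical."""
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def f1_score(prediction, ground_truth):
    """Harmonic mean of token precision and recall after normalization."""
    pred_tokens = normalize_answer(prediction).split()
    gt_tokens = normalize_answer(ground_truth).split()
    num_same = sum((Counter(pred_tokens) & Counter(gt_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gt_tokens)
    return 2 * precision * recall / (precision + recall)


print(exact_match("the New York!", "new york"))
print(f1_score("New York City", "New York"))
```

A prediction of "New York City" against the reference "New York" is not an exact match, but it still earns partial F1 credit because two of its three tokens overlap with the reference.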
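step-2 Code walkthrough
1. SQuAD data format
SQuAD is distributed as a JSON file.
The top-level object has dict_keys(['data', 'version']), where:
1. 'version' is 1.1, the SQuAD version.
2. 'data' is the training data: a <class 'list'> of 442 items, each of which is a dict.
data[11] has dict_keys(['title', 'paragraphs']) ## taking data[11] as an example
2.1 'title' is the title of the article.
2.2 'paragraphs' holds the paragraphs of the article: a <class 'list'> of n items (148 here), each of which is a dict.
paragraphs[0] has dict_keys(['context', 'qas'])
2.2.1 'context' is the text of the paragraph.
2.2.2 'qas' is a list of the questions for this paragraph (6 here), each of which is a dict.
Some questions in qas are quite similar and share the same answer: they are different phrasings of the same question.
qas[0] has dict_keys(['answers', 'question', 'id'])
2.2.2.1 answers: a list whose elements each have dict_keys(['answer_start', 'text']), i.e. the character offset of the answer in the context and the answer text
2.2.2.2 question: the question text
2.2.2.3 id: the question ID
For example: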
{'answers': [{'answer_start': 0, 'text': 'New York'}],
 'question': 'What city in the United States has the highest population?',
 'id': '56ce304daab44d1400b8850e'}
{'answers': [{'answer_start': 0, 'text': 'New York'}],
 'question': 'In what city is the United Nations based?',
 'id': '56ce304daab44d1400b8850f'}
{'answers': [{'answer_start': 0, 'text': 'New York'}],
 'question': 'What city has been called the cultural capital of the world?',
 'id': '56ce304daab44d1400b88510'}
{'answers': [{'answer_start': 0, 'text': 'New York'}],
 'question': 'What American city welcomes the largest number of legal immigrants?',
 'id': '56ce304daab44d1400b88511'}
{'answers': [{'answer_start': 22, 'text': 'New York City'}],
 'question': 'The major gateway for immigration has been which US city?',
 'id': '56cf5d41aab44d1400b89130'}
{'answers': [{'answer_start': 22, 'text': 'New York City'}],
 'question': 'The most populated city in the United States is which city?',
 'id': '56cf5d41aab44d1400b89131'}
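The nesting described above (data, then paragraphs, then qas, then answers) can be walked with four loops. The sketch below uses a tiny in-memory stand-in instead of the real train-v1.1.json: the title and the context sentence are made up for illustration, while the questions and ids are taken from the examples above. The helper name iter_answers is also an assumption. It additionally checks that answer_start really points at the answer text inside the context.

```python
# A toy dataset with the same shape as a SQuAD v1.1 file.
# The title and context are invented; only the structure matters here.
sample = {
    "version": "1.1",
    "data": [{
        "title": "Toy_Article",
        "paragraphs": [{
            "context": "New York is a large city.",
            "qas": [
                {"answers": [{"answer_start": 0, "text": "New York"}],
                 "question": "What city in the United States has the highest population?",
                 "id": "56ce304daab44d1400b8850e"},
                {"answers": [{"answer_start": 0, "text": "New York"}],
                 "question": "In what city is the United Nations based?",
                 "id": "56ce304daab44d1400b8850f"},
            ],
        }],
    }],
}


def iter_answers(dataset):
    """Yield (title, question_id, question, answer_text, offset_ok) tuples."""
    for article in dataset["data"]:
        for paragraph in article["paragraphs"]:
            context = paragraph["context"]
            for qa in paragraph["qas"]:
                for answer in qa["answers"]:
                    start = answer["answer_start"]
                    text = answer["text"]
                    # answer_start is a character offset into the context.
                    offset_ok = context[start:start + len(text)] == text
                    yield article["title"], qa["id"], qa["question"], text, offset_ok


for row in iter_answers(sample):
    print(row)
```

For the real file, the same function could be applied to the object returned by json.load on train-v1.1.json.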
2. Reading and converting the data
2-1. Source data structure: SquadExample
SquadExample holds one parsed example from the raw SQuAD data.
class SquadExample(object):
    """
    A single training/test example for the squad dataset.
    For examples without an answer, the start and end position are -1.
    """
    def __init__(self, qas_id, question_text, doc_tokens,
                 orig_answer_text=None, start_position=None, end_position=None,
                 is_impossible=None):
        self.qas_id = qas_id
        self.question_text = question_text
        self.doc_tokens = doc_tokens
        self.orig_answer_text = orig_answer_text
        self.start_position = start_position
        self.end_position = end_position
        self.is_impossible = is_impossible
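Here qas_id is the example ID, question_text is the question text, doc_tokens is the reading passage (the context, split into tokens), orig_answer_text is the original answer text, start_position and end_position are where the answer starts and ends in the passage (as indices into doc_tokens), and is_impossible is the negative-example flag used in SQuAD 2.0 (it can be ignored for now).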
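The relationship between the character offset stored in the JSON file (answer_start) and the token positions stored in SquadExample can be sketched as follows. This is a simplified illustration: the context sentence is made up, and the helper name tokenize_with_offsets is an assumption. The idea is to split the context on whitespace while recording, for every character, which token it belongs to.

```python
def is_whitespace(c):
    """Space, tab, CR, LF and U+202F (narrow no-break space) separate tokens."""
    return c in " \t\r\n" or ord(c) == 0x202F


def tokenize_with_offsets(context):
    """Split on whitespace and map every character index to a token index."""
    doc_tokens = []
    char_to_word_offset = []
    prev_is_whitespace = True
    for c in context:
        if is_whitespace(c):
            prev_is_whitespace = True
        else:
            if prev_is_whitespace:
                doc_tokens.append(c)   # start a new token
            else:
                doc_tokens[-1] += c    # extend the current token
            prev_is_whitespace = False
        char_to_word_offset.append(len(doc_tokens) - 1)
    return doc_tokens, char_to_word_offset


# A made-up context; answer_start is a character offset, as in the JSON file.
context = "The city of New York City is large."
answer_text = "New York City"
answer_start = context.index(answer_text)

doc_tokens, char_to_word_offset = tokenize_with_offsets(context)
start_position = char_to_word_offset[answer_start]
end_position = char_to_word_offset[answer_start + len(answer_text) - 1]

print(doc_tokens)
print(start_position, end_position)
print(" ".join(doc_tokens[start_position:end_position + 1]))
```

With these values, a SquadExample for this question would carry doc_tokens, orig_answer_text set to the answer text, and the two token positions computed above.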
2-2. Reading the source data: read_squad_examples
import json

def read_squad_examples(input_file, is_training, version_2_with_negative):
    # Read a SQuAD json file into a list of SquadExample.
    with open(input_file, "r", encoding='utf-8') as reader:
        input_data = json.load(reader)['data']

    def is_whitespace(c):
        # Space, tab, CR, LF and U+202F (narrow no-break space) count as whitespace.
        if c == ' ' or c == '\t' or c == '\r' or c == '\n' or ord(c) == 0x202F:
            return True
        return False

    examples = []
    for entry in input_data:  # entry is a dict {title, paragraphs}