Walkthrough of the pytorch_transformers example run_squad.py

A SQuAD question-answering system built on a pretrained BERT model

step-1 Running the example

Following Hugging Face's pytorch_transformers repository, download and run the example run_squad.py.

Command-line arguments:

python run_squad.py \
    --model_type bert \
    --model_name_or_path bert-base-uncased \
    --do_train \
    --do_eval \
    --do_lower_case \
    --train_file ../../SQUAD_DIR/train-v1.1.json \
    --predict_file ../../SQUAD_DIR/dev-v1.1.json \
    --per_gpu_train_batch_size 4 \
    --learning_rate 3e-5 \
    --num_train_epochs 2.0 \
    --max_seq_length 384 \
    --doc_stride 128 \
    --output_dir ../../SQUAD_DIR/OUTPUT \
    --overwrite_output_dir \
    --gradient_accumulation_steps 3
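--max_seq_length 384 and --doc_stride 128 control how a passage that does not fit into one input is handled: it is cut into overlapping windows, and each window is paired with the question. The sketch below is modeled on the sliding-window loop of the preprocessing step (convert_examples_to_features in the example's utils_squad.py); the helper name split_doc_spans and the toy token counts are hypothetical, and tokens are represented only by their count.

```python
from typing import List, NamedTuple


class DocSpan(NamedTuple):
    start: int   # index of the first passage token in this window
    length: int  # number of passage tokens in this window


def split_doc_spans(num_doc_tokens: int, num_query_tokens: int,
                    max_seq_length: int = 384, doc_stride: int = 128) -> List[DocSpan]:
    """Hypothetical helper: split a passage into overlapping windows.

    Each input is [CLS] question [SEP] passage-window [SEP], so 3 positions
    go to special tokens and the rest of max_seq_length is shared by the
    question and one window of the passage.
    """
    max_tokens_for_doc = max_seq_length - num_query_tokens - 3
    spans = []
    start_offset = 0
    while start_offset < num_doc_tokens:
        length = min(num_doc_tokens - start_offset, max_tokens_for_doc)
        spans.append(DocSpan(start_offset, length))
        if start_offset + length == num_doc_tokens:
            break  # this window reaches the end of the passage
        # Slide forward by doc_stride (or less, if the window was shorter).
        start_offset += min(length, doc_stride)
    return spans


if __name__ == "__main__":
    # A 500-token passage with a 10-token question needs three windows.
    print(split_doc_spans(500, 10))
    # [DocSpan(start=0, length=371), DocSpan(start=128, length=371), DocSpan(start=256, length=244)]
```

With a 10-token question each window holds 384 - 10 - 3 = 371 passage tokens, and because the stride (128) is smaller than the window, consecutive windows overlap by 371 - 128 = 243 tokens.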

The model configuration, config.json:

 
{
    'attention_probs_dropout_prob': 0.1,
    'finetuning_task': None,
    'hidden_act': 'gelu',
    'hidden_dropout_prob': 0.1,
    'hidden_size': 768,
    'initializer_range': 0.02,
    'intermediate_size': 3072,
    'layer_norm_eps': 1e-12,
    'max_position_embeddings': 512,
    'num_attention_heads': 12,
    'num_hidden_layers': 12,
    'num_labels': 2,
    'output_attentions': False,
    'output_hidden_states': False,
    'pruned_heads': {},
    'torchscript': False,
    'type_vocab_size': 2,
    'vocab_size': 30522
}

The tokenizer configuration, tokenizer_config.json:

{'do_lower_case': True, 'max_len': 512, 'init_inputs': []}
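A few quantities follow directly from these two configs and the run arguments: the size of each attention head, and whether --max_seq_length and --do_lower_case are compatible with the model and tokenizer. The snippet below copies the values shown above into plain dicts (it does not load the files); the check_config helper is hypothetical.

```python
# Values copied from config.json, tokenizer_config.json and the command line above.
model_config = {
    "hidden_size": 768,
    "num_attention_heads": 12,
    "max_position_embeddings": 512,
}
tokenizer_config = {"do_lower_case": True, "max_len": 512}
run_args = {"max_seq_length": 384, "do_lower_case": True}


def check_config(model_config, tokenizer_config, run_args):
    """Hypothetical helper: return the per-head size and check the run arguments fit."""
    hidden = model_config["hidden_size"]
    heads = model_config["num_attention_heads"]
    # Multi-head attention splits the hidden vector evenly across the heads.
    if hidden % heads != 0:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    # Position embeddings exist only for the first max_position_embeddings positions.
    limit = min(model_config["max_position_embeddings"], tokenizer_config["max_len"])
    if run_args["max_seq_length"] > limit:
        raise ValueError("max_seq_length exceeds the model's position limit")
    # bert-base-uncased was trained on lower-cased text.
    if run_args["do_lower_case"] != tokenizer_config["do_lower_case"]:
        raise ValueError("--do_lower_case should match the tokenizer config")
    return hidden // heads


print(check_config(model_config, tokenizer_config, run_args))  # 64
```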

Results printed by run_squad.py:

{"exact": 79.2715, "f1": 86.96, "total": 10570, "HasAns_exact": 79.27, "HasAns_f1": 86.96, "HasAns_total": 10570}

Evaluation with the official SQuAD v1.1 script:

python evaluate-v1.1.py dev-v1.1.json OUTPUT/predictions_.json

{"exact_match": 79.27152317880795, "f1": 86.96144570829648}
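The exact_match and f1 figures are computed per question and then averaged over all questions (scaled by 100). The sketch below re-implements that scoring logic under the assumption that it follows evaluate-v1.1.py: answers are normalized (lower-cased, punctuation and the articles a/an/the removed, whitespace collapsed), F1 is computed over word tokens, and each question takes its best score over all reference answers. The evaluate function and its simplified dict inputs are hypothetical; use the official script for reported numbers.

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)


def normalize_answer(s):
    """Lower-case, drop punctuation and the articles a/an/the, collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in PUNCTUATION)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def exact_match_score(prediction, ground_truth):
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))


def f1_score(prediction, ground_truth):
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    common = Counter(pred_tokens) & Counter(gold_tokens)  # multiset intersection
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    # A question may have several reference answers; the best match counts.
    return max(metric_fn(prediction, gt) for gt in ground_truths)


def evaluate(predictions, references):
    """predictions: {qid: text}; references: {qid: [gold texts]} (hypothetical format)."""
    em = f1 = 0.0
    for qid, golds in references.items():
        pred = predictions.get(qid, "")  # a missing prediction scores 0
        em += metric_max_over_ground_truths(exact_match_score, pred, golds)
        f1 += metric_max_over_ground_truths(f1_score, pred, golds)
    n = len(references)
    return {"exact_match": 100.0 * em / n, "f1": 100.0 * f1 / n}


print(evaluate({"q1": "the New York City", "q2": "New York"},
               {"q1": ["New York City"], "q2": ["New York City"]}))
```

Here q1 matches exactly after normalization, while q2 gets EM 0 but F1 0.8 (precision 2/2, recall 2/3), so the two questions average to EM 50 and F1 90. This is why F1 is always at least as high as exact match.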

step-2 Code walkthrough
1. SQuAD data format
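 SQuAD: a JSON file
     The JSON file is a dict with keys dict_keys(['data', 'version']), where:
     1. 'version' is "1.1", the SQuAD version.
     2. 'data' holds the training data: a <class 'list'> of 442 items, each a dict (one per article).
         Taking data[11] as an example, it is a dict with keys dict_keys(['title', 'paragraphs']):
         2.1 'title' is the article title.
         2.2 'paragraphs' holds the article's paragraphs: a <class 'list'> of n items (148 here), each a dict.
             paragraphs[0] is a dict with keys dict_keys(['context', 'qas']):
             2.2.1 'context' is the text of the paragraph.
             2.2.2 'qas' is a list of the questions about this paragraph (6 here), each a dict.
                 Some questions in qas are very similar and share the same answer: they are different phrasings that lead to one answer.
                 qas[0] is a dict with keys dict_keys(['answers', 'question', 'id']):
                 2.2.2.1 answers: a list whose elements are dicts with keys dict_keys(['answer_start', 'text']); answer_start is the character offset of the answer in context.
                 2.2.2.2 question: the question text.
                 2.2.2.3 id: the question ID.
                 For example: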
                     {'answers': [{'answer_start': 0, 'text': 'New York'}],
                      'question': 'What city in the United States has the highest population?',
                      'id': '56ce304daab44d1400b8850e'}
                     {'answers': [{'answer_start': 0, 'text': 'New York'}],
                      'question': 'In what city is the United Nations based?',
                      'id': '56ce304daab44d1400b8850f'}
                     {'answers': [{'answer_start': 0, 'text': 'New York'}],
                      'question': 'What city has been called the cultural capital of the world?',
                      'id': '56ce304daab44d1400b88510'}
                     {'answers': [{'answer_start': 0, 'text': 'New York'}],
                      'question': 'What American city welcomes the largest number of legal immigrants?',
                      'id': '56ce304daab44d1400b88511'}
                     {'answers': [{'answer_start': 22, 'text': 'New York City'}],
                      'question': 'The major gateway for immigration has been which US city?',
                      'id': '56cf5d41aab44d1400b89130'}
                     {'answers': [{'answer_start': 22, 'text': 'New York City'}],
                      'question': 'The most populated city in the United States is which city?',
                      'id': '56cf5d41aab44d1400b89131'}
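The nesting above (data, then paragraphs, then qas, then answers) maps onto a few nested loops. The sketch below uses a tiny hand-written sample in the same layout (its text, title and IDs are made up) instead of the real train-v1.1.json, counts the items at each level, and checks that answer_start is a character offset into context. The summarize_squad helper is hypothetical.

```python
import json

# A made-up, minimal document in the SQuAD v1.1 layout (not real data).
sample_json = """
{"version": "1.1",
 "data": [
  {"title": "Toy_article",
   "paragraphs": [
    {"context": "New York City is the most populous city in the United States.",
     "qas": [
      {"id": "toy-q1",
       "question": "What is the most populous US city?",
       "answers": [{"answer_start": 0, "text": "New York City"}]},
      {"id": "toy-q2",
       "question": "Which country is New York City in?",
       "answers": [{"answer_start": 47, "text": "United States"}]}
     ]}
   ]}
 ]}
"""


def summarize_squad(squad):
    """Count articles, paragraphs and questions; verify every answer offset."""
    n_articles = len(squad["data"])
    n_paragraphs = n_questions = 0
    for entry in squad["data"]:                # entry: {'title', 'paragraphs'}
        for paragraph in entry["paragraphs"]:  # paragraph: {'context', 'qas'}
            n_paragraphs += 1
            context = paragraph["context"]
            for qa in paragraph["qas"]:        # qa: {'answers', 'question', 'id'}
                n_questions += 1
                for answer in qa["answers"]:   # answer: {'answer_start', 'text'}
                    start = answer["answer_start"]
                    # answer_start counts characters, not words
                    assert context[start:start + len(answer["text"])] == answer["text"]
    return n_articles, n_paragraphs, n_questions


squad = json.loads(sample_json)
print(squad["version"], summarize_squad(squad))  # 1.1 (1, 1, 2)
```

Run on the real training file (loaded with json.load), the article count should come out as the 442 items mentioned above.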
2. Reading and converting the data
2-1. The raw-data structure: SquadExample

SquadExample holds one example parsed from the raw SQuAD data.

class SquadExample(object):
    """
    A single training/test example for the SQuAD dataset.
    For examples without an answer, the start and end position are -1.
    """
    def __init__(self, qas_id, question_text, doc_tokens,
                 orig_answer_text=None, start_position=None, end_position=None,
                 is_impossible=None):
        self.qas_id = qas_id                      # question ID from the JSON
        self.question_text = question_text        # question string
        self.doc_tokens = doc_tokens              # passage split on whitespace
        self.orig_answer_text = orig_answer_text  # answer string as given in the JSON
        self.start_position = start_position      # index in doc_tokens of the first answer word
        self.end_position = end_position          # index in doc_tokens of the last answer word
        self.is_impossible = is_impossible        # SQuAD 2.0 only: question has no answer

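Here qas_id is the example ID, question_text is the question text, doc_tokens is the reading passage split into whitespace-delimited words, orig_answer_text is the answer text as given in the data, start_position and end_position are the indices in doc_tokens of the first and last word of the answer, and is_impossible is the flag for unanswerable (negative) questions used in SQuAD 2.0 (it can be ignored for now).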
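In the JSON, answer_start is a character offset, while start_position and end_position are word indices into doc_tokens. The sketch below shows that conversion: split the context on whitespace while recording, for every character, the index of the word it belongs to (char_to_word_offset), then look up the first and last character of the answer. It follows the logic of the original read_squad_examples, but the helper names whitespace_tokenize_with_offsets and answer_word_span, and the toy sentence, are hypothetical.

```python
def is_whitespace(c):
    # Same test as in read_squad_examples; 0x202F is the narrow no-break space.
    return c in (" ", "\t", "\r", "\n") or ord(c) == 0x202F


def whitespace_tokenize_with_offsets(text):
    """Split text on whitespace and map each character to the index of its word."""
    doc_tokens = []
    char_to_word_offset = []
    prev_is_whitespace = True
    for c in text:
        if is_whitespace(c):
            prev_is_whitespace = True
        else:
            if prev_is_whitespace:
                doc_tokens.append(c)   # start a new word
            else:
                doc_tokens[-1] += c    # extend the current word
            prev_is_whitespace = False
        # Whitespace maps to the preceding word (-1 before the first word).
        char_to_word_offset.append(len(doc_tokens) - 1)
    return doc_tokens, char_to_word_offset


def answer_word_span(context, answer_text, answer_start):
    """Convert a character-level answer into (start_position, end_position) word indices."""
    doc_tokens, char_to_word_offset = whitespace_tokenize_with_offsets(context)
    start_position = char_to_word_offset[answer_start]
    end_position = char_to_word_offset[answer_start + len(answer_text) - 1]
    return doc_tokens, start_position, end_position


tokens, start, end = answer_word_span(
    "Many immigrants arrive in New York City each year.", "New York City", 26)
print(start, end, tokens[start:end + 1])  # 4 6 ['New', 'York', 'City']
```

Working at the word level first keeps the later WordPiece step simple: each whitespace word is tokenized separately, so word indices can be mapped to sub-token indices afterwards.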

2-2. Reading the raw data: read_squad_examples
import json

def read_squad_examples(input_file, is_training, version_2_with_negative):
    """Read a SQuAD json file into a list of SquadExample."""
    with open(input_file, "r", encoding='utf-8') as reader:
        input_data = json.load(reader)['data']

    def is_whitespace(c):
        # 0x202F (narrow no-break space) is also treated as a word separator
        if c == ' ' or c == '\t' or c == '\r' or c == '\n' or ord(c) == 0x202F:
            return True
        return False

    examples = []
    for entry in input_data:  # entry is a dict with keys 'title' and 'paragraphs'