2021SC@SDUSC
Building on the previous weeks' posts, we have already developed a solid understanding of the DrFact model. By working through its source code, we have learned not only how the model's algorithms actually work, but also a great deal about machine learning, deep learning, and NLP along the way. In this installment I will analyze the last remaining source file, and in doing so we will gain a much more complete picture of DrFact's end-to-end pipeline.
1. Analysis of the run_drfact.py Source File
The subject of this analysis is run_drfact.py. This file defines a large number of classes and functions, and its line count is the longest of any source file in the project, which alone suggests how central it is to the model's training process. With that, let's begin.
1.1 Imported Modules
Fittingly for the final installment of this project, this file imports more modules than any we have seen so far. Besides the usual standard-library modules and old friends such as absl, numpy, and tensorflow, it also brings in the albert and bert modules for tokenization and modeling; this time the encoders are used directly in their full form rather than borrowed through DrKit.
That said, we still make use of DrKit's search_utils module, which we have already analyzed, as well as several DrFact modules introduced in earlier posts (evaluate, input_fns, model_fns).
import collections
import functools
import json
import os
import re
import time
from absl import flags
from albert import tokenization as albert_tokenization
from bert import modeling
from bert import optimization
from bert import tokenization as bert_tokenization
from language.labs.drfact import evaluate
from language.labs.drfact import input_fns
from language.labs.drfact import model_fns
from language.labs.drkit import search_utils
import numpy as np
import random
import tensorflow.compat.v1 as tf
# from tfdeterminism import patch
# patch()
from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
from tensorflow.contrib import memory_stats as contrib_memory_stats
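To illustrate how the two tokenization backends might be selected at runtime based on the tokenizer_type flag, here is a minimal, self-contained sketch. The make_tokenizer dispatch function and the two stub classes are hypothetical stand-ins for illustration; the real file works with the imported bert_tokenization and albert_tokenization modules directly.

```python
# Hypothetical sketch: dispatch on a tokenizer_type string.
# BertTokenizer/AlbertTokenizer are illustrative stubs, not the real classes.

class BertTokenizer:
    """Stub for a WordPiece-style tokenizer built from a vocab file."""

    def __init__(self, vocab_file, do_lower_case=True):
        self.vocab_file = vocab_file
        self.do_lower_case = do_lower_case


class AlbertTokenizer:
    """Stub for a SentencePiece-based tokenizer needing an spm model file."""

    def __init__(self, vocab_file, spm_model_file):
        self.vocab_file = vocab_file
        self.spm_model_file = spm_model_file


def make_tokenizer(tokenizer_type, vocab_file, model_file=None):
    """Pick a tokenizer backend keyed on the tokenizer_type flag value."""
    if tokenizer_type == "bert_tokenization":
        return BertTokenizer(vocab_file, do_lower_case=True)
    elif tokenizer_type == "albert_tokenization":
        return AlbertTokenizer(vocab_file, spm_model_file=model_file)
    raise ValueError("Unknown tokenizer_type: %s" % tokenizer_type)


tok = make_tokenizer("bert_tokenization", "vocab.txt")
print(type(tok).__name__)  # BertTokenizer
```

The point of the sketch is simply that the flag value, not a hard-coded import, decides which encoder family is used, which is why both tokenization modules are imported up front.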
1.2 flags Parameters
Next come the flags parameters. This file defines more flags than any we have covered before, so I won't describe them one by one, which would be redundant; the code below, with each flag carrying its own help string, should be clear enough on its own.
FLAGS = flags.FLAGS
## Required parameters
flags.DEFINE_string(
    "bert_config_file", None,
    "The config json file corresponding to the pre-trained BERT model. "
    "This specifies the model architecture.")
flags.DEFINE_string("tokenizer_type", "bert_tokenization",
                    "The tokenizer type that the BERT model was trained with.")
flags.DEFINE_string("tokenizer_model_file", None,
                    "The tokenizer model that the BERT was trained with.")
flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")
flags.DEFINE_string(
    "output_dir", None,
    "The output directory where the model checkpoints will be written.")
flags.DEFINE_string(
    "output_prediction_file", "test_predictions.json",
    "The output file where the model predictions will be written.")
## Other parameters
flags.DEFINE_string("train_file", None, "JSON for training.")
flags.DEFINE_string("predict_file", None, "JSON for predictions.")
flags.DEFINE_string("predict_prefix", "dev", "Prefix for prediction outputs.")
flags.DEFINE_string("test_file", None, "JSON for testing.")
flags.DEFINE_string("data_type", "onehop",
                    "Whether queries are `onehop` or `twohop`.")
flags.DEFINE_string("model_type", "drfact",
                    "Whether to use `drfact` or `drkit` model.")
flags.DEFINE_string(
    "init_checkpoint", None,
    "Initial checkpoint (usually from a pre-trained BERT model).")
flags.DEFINE_string("train_data_dir", None,
                    "Location of entity/mention/fact files for training data.")
flags.DEFINE_string("f2f_index_dir", None,
                    "Location of fact2fact files for training data.")
flags.DEFINE_string("test_data_dir", None,
                    "Location of entity/mention/fact files for test data.")
flags.DEFINE_string("model_ckpt_toload", "best_model",
                    "Name of the checkpoints.")
flags.DEFINE_string("test_model_ckpt", "best_model", "Name of the checkpoints.")
flags.DEFINE_string("embed_index_prefix", "bert_large", "Prefix of indexes.")
flags.DEFINE_integer("num_hops", 2, "Number of hops in rule template.")
flags.DEFINE_integer("max_entity_len", 4,
                     "Maximum number of tokens in an entity name.")
flags.DEFINE_integer(
    "num_mips_neighbors", 100,
    "Number of nearest neighbor mentions to retrieve for queries in each hop.")
flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")
flags.DEFINE_integer(
    "projection_dim", None, "Number of dimensions to project embeddings to. "
    "Set to None to use full dimensions.")
flags.DEFINE_integer(
    "max_query_length", 64,
    "The maximum number of tokens for the question. Questions longer than "
    "this will be truncated to this length.")
flags.DEFINE_bool("do_train", False, "Whether to run training.")
flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")
flags.DEFINE_bool("do_test", False, "Whether to run eval on the test set.")
flags.DEFINE_float(
    "subject_mention_probability", 0.0,
    "Fraction of training instances for which we use subject "
    "mentions in the text as opposed to canonical names.")
flags.DEFINE_integer("train_batch_size", 16, "Total batch size for training.")
flags.DEFINE_int
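All of the DEFINE_* calls above follow the same absl.flags pattern: each one registers a flag name, a default value, and a help string into the global FLAGS object, whose values are later read as attributes (e.g. FLAGS.num_hops). A stdlib-only toy imitation of that registry pattern, just to make the mechanism concrete (the FlagRegistry class is illustrative and is not absl's actual implementation, which also handles command-line parsing and type validation):

```python
class FlagRegistry:
    """Toy imitation of absl.flags.FLAGS: each DEFINE_* call stores a
    (default, help) pair, and the value is later read as an attribute."""

    def __init__(self):
        self._values = {}
        self._help = {}

    def DEFINE_string(self, name, default, help_text):
        # Register the flag's default value and its help string.
        self._values[name] = default
        self._help[name] = help_text

    # In this toy version, all flag types share one registration path.
    DEFINE_bool = DEFINE_integer = DEFINE_float = DEFINE_string

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails, so reading
        # FLAGS.some_flag falls through to the registered values.
        try:
            return self._values[name]
        except KeyError:
            raise AttributeError(name)


FLAGS = FlagRegistry()
FLAGS.DEFINE_string("model_type", "drfact",
                    "Whether to use `drfact` or `drkit` model.")
FLAGS.DEFINE_integer("num_hops", 2, "Number of hops in rule template.")
print(FLAGS.model_type, FLAGS.num_hops)  # drfact 2
```

In the real file, of course, the defaults can be overridden from the command line (e.g. `--num_hops=3`) because absl's app.run parses sys.argv into FLAGS before main is invoked; here we only read back the registered defaults.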