2021SC@SDUSC
前篇背景介绍:前几周的源代码分析中,我们已经了解了drfact是如何对语料库进行预处理的,也了解了drfact模型算法的前几步都做了什么事情。但这一周的源代码分析我不会对具体的源代码进行分析,原因在于我在本周进行源代码分析,并回顾了过往的源代码分析内容时,注意到drfact模型对其他模型进行了一定程度的借鉴,这一点尤其体现在其核心源代码之中——调用了其他模型中编写好的函数。因此,体现在源代码之中的内容也就不再是仅仅只要关注到drfact这一个项目包即可,而是需要对整个OpenCSR项目的其他源代码也进行审视。
承接上文,在上一周的源代码分析中,我主要描述了DrFact模型与DrKit模型之间的勾连借鉴。这一点主要体现在了DrFact模型的各个模块对于DrKit模型中某些模块的调用,其中尤其以input_fns.py和model.fns.py源文件为典型,在这两个模型中,都出现了名字如上述一般的源文件,由此可见二者在功能定位上应该有相近之处。此外,在对DrFact模型中模块的分析中我们可以认识到,有些来自于DrKit模型的函数在DrFact模型中被反复用到,且是跨越多个模块被调用,这让我们可以确定这些函数具有重要的分析价值。因此,在本周的这篇源代码分析中,我将主要阐述这些DrKit模型中的函数定义,希望通过更加细微的刻画来展现这些函数的具体用途。
二、DrKit模型的已定义函数
2.1 BERT与bert_utils_v2.py模块
要提及DrKit模型中对DrFact模型中产生贡献的模块,首先要谈到BERT。BERT在前面的源代码分析中有提及过,它的全称为Bidirectional Encoder Representation from Transformers,是一个预训练的语言表征模型。它强调了不再像以往一样采用传统的单向语言模型或者把两个单向语言模型进行浅层拼接的方法进行预训练,而是采用新的masked language model(MLM),以致能生成深度的双向语言表征。
有了上文做一些铺垫,我们可以转头去看有关DrKit模型中定义的有关BERT的模块。有上一篇的源码分析内容可以很容易知道,与BERT最直接也是唯一相关的一个模块是bert_utils_v2.py这个模块,在这个模块中定义有一个BERTPredictor类。根据其注释易知,这个类就是一个封装了BERT模型的编码器。
BERTPredictor类的构造函数如下所示,可以很清晰的知道,这个类共有七个成员变量,其中的序列最大长度max_seq_length,查询最大长度max_qry_length,实体最大长度max_seq_length,词向量大小emb_dim和批大小(这里的batch即是神经网络模型中的那个)batch_size都是通过源代码中的flags定义的参数传进来的,而分词器tokenizer则是通过构造函数的形参传进来的。
此外,在构造函数的形参中还有一个缺省值为None的形参estimator,当它是None时,即没有一个明确的estimator时,构造函数则会根据flags定义的bert_config_file参数指向的文件为bert模型配置参数,而运行配置run_config则采用tensorflow中的estimator配置,至于QA系统的配置则使用DrKit中写好的run_dualencoder_qa模块中的QAConfig()函数,传入flags定义好的参数,从而形成这里的QA系统配置(由于继续深挖下去就没完没了了,也不属于核心代码范畴,因此对于run_dualencoder_qa模块就不细究了,不过望文生义来说,这个模块应该是用来运行QA系统的双工编码器的)。model_fn亦同理,采用了DrKit中定义好的模块内容进行初始化,最后根据上述model_fn和配置内容,通过tensorflow中TUREstimator构造函数来构建estimator,再将这个estimator作为FastPredictor的参数,构建出一个BERTPredictor类的最后一个成员变量fast_predictor,由此,这样一个BERT模型就构建完成了。
class BERTPredictor:
"""Wrapper around a BERT model to encode text."""
def __init__(self, tokenizer, init_checkpoint, estimator=None):
"""Setup BERT model."""
self.max_seq_length = FLAGS.max_seq_length
self.max_qry_length = FLAGS.max_query_length
self.max_ent_length = FLAGS.max_entity_length
self.batch_size = FLAGS.predict_batch_size
self.tokenizer = tokenizer
if estimator is None:
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
run_config = tf.estimator.tpu.RunConfig()
qa_config = run_dualencoder_qa.QAConfig(
doc_layers_to_use=FLAGS.doc_layers_to_use,
doc_aggregation_fn=FLAGS.doc_aggregation_fn,
qry_layers_to_use=FLAGS.qry_layers_to_use,
qry_aggregation_fn=FLAGS.qry_aggregation_fn,
projection_dim=FLAGS.projection_dim,
normalize_emb=FLAGS.normalize_emb,
share_bert=True,
exclude_scopes=None)
model_fn_builder = run_dualencoder_qa.model_fn_builder
model_fn = model_fn_builder(
bert_config=bert_config,
qa_config=qa_config,
init_checkpoint=init_checkpoint,
learning_rate=0.0,
num_train_steps=0,
num_warmup_steps=0,
use_tpu=False,
use_one_hot_embeddings=False)
# If TPU is not available, this will fall back to normal Estimator on CPU
# or GPU.
estimator = tf.estimator.tpu.TPUEstimator(
use_tpu=False,
model_fn=model_fn,
config=run_config,
train_batch_size=self.batch_size,
predict_batch_size=self.batch_size)
self.fast_predictor = FastPredict(estimator, self.get_input_fn)
self.emb_dim = FLAGS.projection_dim
此外,该类中还有五个成员函数,get_input_fn()函数用于返回一个接受了生成器generator的input_fn,_run_on_features()函数则是根据输入特征进行预测,get_features()函数则是将输入的一系列分词转化为一个特征字典的形式,get_doc_embedding()函数显然是用BERT模型来获取文档中的词向量的,而get_qry_embedding()函数则是用BERT模型来获取问题中的词向量。
其函数定义如下:
def get_input_fn(self, generator):
"""Return an input_fn which accepts a generator."""
def _input_fn(params):
"""Convert input into features."""
del params
seq_length = self.max_seq_length
qry_length = self.max_qry_length
ent_length = self.max_ent_length
d = tf.data.Dataset.from_generator(
generator,
output_types={
"unique_ids": tf.int32,
"doc_input_ids": tf.int32,
"doc_input_mask": tf.int32,
"doc_segment_ids": tf.int32,
"qry_input_ids": tf.int32,
"qry_input_mask": tf.int32,
"qry_segment_ids": tf.int32,
"ent_input_ids": tf.int32,
"ent_input_mask": tf.int32,
},
output_shapes={
"unique_ids": tf.TensorShape([]),
"doc_input_ids": tf.TensorShape([seq_length]),
"doc_input_mask": tf.TensorShape([seq_length]),
"doc_segment_ids": tf.TensorShape([seq_length]),
"qry_input_ids": tf.TensorShape([qry_length]),
"qry_input_mask": tf.TensorShape([qry_length]),
"qry_segment_ids": tf.TensorShape([qry_length]),
"ent_input_ids": tf.TensorShape([ent_length]),
"ent_input_mask": tf.TensorShape([ent_length]),
})
d = d.batch(batch_size=self.batch_size)
return d
return _input_fn
def _run_on_features(self, features):
"""Run predictions for given features."""
current_size = len(features)
if current_size < self.batch_size:
features += [features[-1]] * (self.batch_size - current_size)
return self.fast_predictor.predict(features)[:current_size]
def get_features(self, doc_tokens, qry_tokens, ent_tokens, uid):
"""Convert list of tokens to a feature dict."""
max_tokens_doc = self.max_seq_length - 2
max_tokens_qry = self.max_qry_length - 2
max_tokens_ent = self.max_ent_length
doc_input_ids = self.tokenizer.convert_tokens_to_ids(
["[CLS]"] + doc_tokens[:max_tokens_doc] + ["[SEP]"])
doc_segment_ids = [1] * len(doc_input_ids)
doc_input_mask = [1] * len(doc_input_ids)
while len(doc_input_ids) < self.max_seq_length:
doc_input_ids.append(0)
doc_input_mask.append(0)
doc_segment_ids.append(0)
qry_input_ids = self.tokenizer.convert_tokens_to_ids(
["[CLS]"] + qry_tokens[:max_tokens_qry] + ["[SEP]"])
qry_segment_ids = [0] * len(qry_input_ids)
qry_input_mask = [1] * len(qry_input_ids)
while len(qry_input_ids) < self.max_qry_length:
qry_input_ids.append(0)
qry_input_mask.append(0)
qry_segment_ids.append(0)
ent_input_ids = self.tokenizer.convert_tokens_to_ids(
ent_tokens[:max_tokens_ent])
ent_input_mask = [1] * len(ent_input_ids)
while len(ent_input_ids) < self.max_ent_length:
ent_input_ids.append(0)
ent_input_mask.append(0)
return {
"unique_ids": uid,
"doc_input_ids": doc_input_ids,
"doc_input_mask": doc_input_mask,
"doc_segment_ids": doc_segment_ids,
"qry_input_ids": qry_input_ids,
"qry_input_mask": qry_input_mask,
"qry_segment_ids": qry_segment_ids,
"ent_input_ids": ent_input_ids,
"ent_input_mask": ent_input_mask,
}
def get_doc_embeddings(self, docs):
"""Run BERT to get features for docs.
Args:
docs: List of list of tokens.
Returns:
embeddings: Numpy array of token features.
"""
num_batches = (len(docs) // self.batch_size) + 1
tf.logging.info("Total batches for BERT = %d", num_batches)
embeddings = np.empty((len(docs), self.max_seq_length, self.emb_dim),
dtype=np.float32)
for nb in tqdm(range(num_batches)):
min_ = nb * self.batch_size
max_ = (nb + 1) * self.batch_size
if min_ >= len(docs):
break
if max_ > len(docs):
max_ = len(docs)
current_features = [
self.get_features(docs[ii], ["dummy"], ["dummy"], ii)
for ii in range(min_, max_)
]
results = self._run_on_features(current_features)
for ir, rr in enumerate(results):
embeddings[min_ + ir, :, :] = rr["doc_features"]
return embeddings[:, 1:, :] # remove [CLS]
def get_qry_embeddings(self, qrys, ents):
"""Run BERT to get features for queries.
Args:
qrys: List of list of tokens.
ents: List of list of tokens.
Returns:
st_embeddings: Numpy array of token features.
en_embeddings: Numpy array of token features.
bow_embeddings: Numpy array of token features.
"""
num_batches = (len(qrys) // self.batch_size) + 1
tf.logging.info("Total batches for BERT = %d", num_batches)
st_embeddings = np.empty((len(qrys), self.emb_dim), dtype=np.float32)
en_embeddings = np.empty((len(qrys), self.emb_dim), dtype=np.float32)
bow_embeddings = np.empty((len(qrys), self.emb_dim), dtype=np.float32)
for nb in tqdm(range(num_batches)):
min_ = nb * self.batch_size
max_ = (nb + 1) * self.batch_size
if min_ >= len(qrys):
break
if max_ > len(qrys):
max_ = len(qrys)
current_features = [
self.get_features(["dummy"], qrys[ii], ents[ii], ii)
for ii in range(min_, max_)
]
results = self._run_on_features(current_features)
for ir, rr in enumerate(results):
st_embeddings[min_ + ir, :] = rr["qry_st_features"]
en_embeddings[min_ + ir, :] = rr["qry_en_features"]
bow_embeddings[min_ + ir, :] = rr["qry_bow_features"]
return st_embeddings, en_embeddings, bow_embeddings
2.2 search_utils.py模块与其中的三个重要函数
在上一篇源代码分析中提到,在DrKit模型中有一个模块中的函数被非常频繁的使用,这个模块就是search_utils.py模块。在这个模块中,共有三个重要函数被DrFact模型反复使用,它们分别是create_mips_searcher()函数,write_to_checkpoint()函数和write_ragged_to_checkpoint()函数。接下来,我将对这三个函数进行一定程度的分析。
2.2.1 create_mips_searcher()函数
对于create_mips_searcher()函数,其实我们可以望文生义地知道这是一个mips(最大内积搜索)的搜索器。不过基于它的重要作用,我们还是分析一下它的代码。
在这个函数中,首先使用load_database函数获取了将要用于与输入问题进行最大内积搜索的数据库(路径等变量通过create_mips_searcher()函数的形参传递进来)。
然后使用tensorflow的control_dependencies来进行MIPS的初始化。
最后在确认DB被初始化完成后,定义内部函数_search()——通过使用tensorflow的matmul方法来获取所有距离,然后使用tensorflow.nn中的top_k方法获取最大内积的目标距离和索引(从dist这样的命名也可以看出最大内积搜索和KNN真的非常像啊!),最后对这两个生成的topk_dist和topk_idx进行一定程度的整形之后,将其返回。
完成上述操作之后,返回我们的tf_db和_search函数,就得到了我们需要的mips(最大内积搜索)搜索器。
def create_mips_searcher(var_name, checkpoint_path, num_neighbors, local_var_name="mips_init_barrier"):
"""Create searcher for returning top-k closest elements."""
tf_db = load_database(var_name, None, checkpoint_path)
with tf.control_dependencies([tf_db.initializer]):
mips_init_barrier = tf.constant(True)
# Make sure DB is initialized.
tf.get_local_variable(local_var_name, initializer=mips_init_barrier)
def _search(query):
with tf.device("/cpu:0"):
distance = tf.matmul(query, tf_db, transpose_b=True)
topk_dist, topk_idx = tf.nn.top_k(distance, num_neighbors)
topk_dist.set_shape([query.shape[0], num_neighbors])
topk_idx.set_shape([query.shape[0], num_neighbors])
return topk_dist, topk_idx
return tf_db, _search
2.2.2 write_to_checkpoint()函数
介绍完了我们的最大内积搜索器之后,接下来介绍另一个函数——write_to_checkpoint()函数。这个函数的代码不多,但却有着极为重要的功能。下面我将对此进行分析。
首先可以从注释中看出,该函数的意义在于将numpy的数组形式数据转化为checkpoint,即检查点的形式。这里,我们的numpy数组即是函数的形参np_db。
在声明使用了tensorflow.Graph()之后,即将np_db通过lambda函数(tensorflow.py_func其实就是执行后续函数,这里是保证了tensorflow的原有内容不变情况下的额外扩展接口以增强其灵活性)处理,再设置相同行列数后保存至init_value中作为tf.db的初始器参数,然后再将tf_db的列表作为tensorflow.train.Saver的构造参数,构造保存器saver。
最后,通过tensorflow的session变量,执行tensorflow.global_variables_initializer()函数,将其结果通过保存器saver保存于checkpoint_path的路径中。
def write_to_checkpoint(var_name, np_db, dtype, checkpoint_path):
"""Write np array to checkpoint."""
with tf.Graph().as_default():
init_value = tf.py_func(lambda: np_db, [], dtype, stateful=False)
init_value.set_shape(np_db.shape)
tf_db = tf.get_variable(var_name, initializer=init_value)
saver = tf.train.Saver([tf_db])
with tf.Session() as session:
session.run(tf.global_variables_initializer())
saver.save(session, checkpoint_path)
2.2.3 write_ragged_to_checkpoint()函数
除了使用write_to_checkpoint()函数以外,有些情况下则应该使用另一个函数——write_ragged_to_checkpoint()函数。这个函数的作用以及处理方式与write_to_checkpoint()函数高度相似,我们可以从下面的代码解析中看出。
首先,根据注释可以看出,这个函数处理的对象不是numpy数组,而是改成了使用CSR的矩阵,而且这里有明确的用途,即是用于加载到不规则张量中。因此,这个函数的使用将会更加受限。
目光转到定义主体,可以看出,该函数定义的前半部分与write_to_checkpoint()函数高度相似,均是使用tensorflow.py_func方法调用lambda函数处理保存数据并进行整形保持行列数。然后使用tensorflow.get_variable()函数来获取初始化的变量,最后通过saver记性保存。不过在最后,有明确的差异在于,函数最后还通过tensorflow打开文件并在末尾标注了scipy矩阵的尺寸以及相关属性。除此以外,执行方法基本与write_to_checkpoint()函数一致,因此可以不必深究其原理,只需要记住这个函数算是一个变体就好了。
def write_ragged_to_checkpoint(var_name, sp_mat, checkpoint_path):
"""Write scipy CSR matrix to checkpoint for loading to ragged tensor."""
data = sp_mat.data
indices = sp_mat.indices
rowsplits = sp_mat.indptr
with tf.Graph().as_default():
init_data = tf.py_func(
lambda: data.astype(np.float32), [], tf.float32, stateful=False)
init_data.set_shape(data.shape)
init_indices = tf.py_func(
lambda: indices.astype(np.int64), [], tf.int64, stateful=False)
init_indices.set_shape(indices.shape)
init_rowsplits = tf.py_func(
lambda: rowsplits.astype(np.int64), [], tf.int64, stateful=False)
init_rowsplits.set_shape(rowsplits.shape)
tf_data = tf.get_variable(var_name + "_data", initializer=init_data)
tf_indices = tf.get_variable(
var_name + "_indices", initializer=init_indices)
tf_rowsplits = tf.get_variable(
var_name + "_rowsplits", initializer=init_rowsplits)
saver = tf.train.Saver([tf_data, tf_indices, tf_rowsplits])
with tf.Session() as session:
session.run(tf.global_variables_initializer())
saver.save(session, checkpoint_path)
with tf.gfile.Open(checkpoint_path + ".info", "w") as f:
f.write(str(sp_mat.shape[0]) + " " + str(sp_mat.getnnz()))
除了这些模块以外,其实还有一些被DrFact模型调用到的模块,不过由于使用频率不高等原因,就不再对此过多赘述了。以上,便是这次源码分析的全部内容。