Source Code Analysis of "Reasoning Problems Based on Commonsense Knowledge": Final Summary

2021SC@SDUSC

Over the past several weeks of analysis, we have built up a solid understanding of the DrFact model. By working through its source code, we have seen exactly how its algorithms work, and along the way we have also learned a great deal about machine learning, deep learning, and NLP. In this final installment, I will analyze the last source file, which gives us a more complete picture of DrFact's end-to-end pipeline.

1. Code Analysis of the run_drfact.py Source File

The subject of this analysis is the source file run_drfact.py. It defines a large number of classes and functions and has the longest line count of any source file in the project, which reflects its central role in the model's training pipeline. Let us begin.

1.1 Imported Modules

Unsurprisingly, as the final piece of the project, this file imports more modules than any other. Besides the usual standard-library modules and familiar third-party packages such as absl, numpy, and tensorflow, it also pulls in the albert and bert tokenization and modeling modules; this time the encoder is used directly in its full form rather than borrowed through DrKit.

That said, we still import the DrKit module search_utils, which we analyzed earlier, as well as several DrFact modules introduced in previous posts (evaluate, input_fns, model_fns).

import collections
import functools
import json
import os
import re
import time

from absl import flags
from albert import tokenization as albert_tokenization
from bert import modeling
from bert import optimization
from bert import tokenization as bert_tokenization
from language.labs.drfact import evaluate
from language.labs.drfact import input_fns
from language.labs.drfact import model_fns
from language.labs.drkit import search_utils
import numpy as np
import random
import tensorflow.compat.v1 as tf
# from tfdeterminism import patch
# patch()

from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
from tensorflow.contrib import memory_stats as contrib_memory_stats

1.2 flags Parameters

Next come the flags parameters. More of them are defined here than in any file we have seen before, so I will not describe each one individually; the help string attached to each definition already makes its purpose clear.

FLAGS = flags.FLAGS

## Required parameters
flags.DEFINE_string(
    "bert_config_file", None,
    "The config json file corresponding to the pre-trained BERT model. "
    "This specifies the model architecture.")
flags.DEFINE_string("tokenizer_type", "bert_tokenization",
                    "The tokenizier type that the BERT model was trained on.")
flags.DEFINE_string("tokenizer_model_file", None,
                    "The tokenizier model that the BERT was trained with.")

flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")

flags.DEFINE_string(
    "output_dir", None,
    "The output directory where the model checkpoints will be written.")

flags.DEFINE_string(
    "output_prediction_file", "test_predictions.json",
    "The output directory where the model checkpoints will be written.")

## Other parameters
flags.DEFINE_string("train_file", None, "JSON for training.")

flags.DEFINE_string("predict_file", None, "JSON for predictions.")
flags.DEFINE_string("predict_prefix", "dev", "JSON for predictions.")

flags.DEFINE_string("test_file", None, "JSON for predictions.")

flags.DEFINE_string("data_type", "onehop",
                    "Whether queries are `onehop` or `twohop`.")

flags.DEFINE_string("model_type", "drfact",
                    "Whether to use `drfact` or `drkit` model.")

flags.DEFINE_string(
    "init_checkpoint", None,
    "Initial checkpoint (usually from a pre-trained BERT model).")

flags.DEFINE_string("train_data_dir", None,
                    "Location of entity/mention/fact files for training data.")

flags.DEFINE_string("f2f_index_dir", None,
                    "Location of fact2fact files for training data.")

flags.DEFINE_string("test_data_dir", None,
                    "Location of entity/mention/fact files for test data.")

flags.DEFINE_string("model_ckpt_toload", "best_model",
                    "Name of the checkpoints.")

flags.DEFINE_string("test_model_ckpt", "best_model", "Name of the checkpoints.")

flags.DEFINE_string("embed_index_prefix", "bert_large", "Prefix of indexes.")

flags.DEFINE_integer("num_hops", 2, "Number of hops in rule template.")

flags.DEFINE_integer("max_entity_len", 4,
                     "Maximum number of tokens in an entity name.")

flags.DEFINE_integer(
    "num_mips_neighbors", 100,
    "Number of nearest neighbor mentions to retrieve for queries in each hop.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_integer(
    "projection_dim", None, "Number of dimensions to project embeddings to. "
    "Set to None to use full dimensions.")

flags.DEFINE_integer(
    "max_query_length", 64,
    "The maximum number of tokens for the question. Questions longer than "
    "this will be truncated to this length.")

flags.DEFINE_bool("do_train", False, "Whether to run training.")

flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")

flags.DEFINE_bool("do_test", False, "Whether to run eval on the test set.")

flags.DEFINE_float(
    "subject_mention_probability", 0.0,
    "Fraction of training instances for which we use subject "
    "mentions in the text as opposed to canonical names.")

flags.DEFINE_integer("train_batch_size", 16, "Total batch size for training.")

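For readers less familiar with absl, the `flags.DEFINE_*` calls above behave much like standard-library argparse arguments: each declares a name, a default, a type, and a help string, and the parsed values become attributes of a global `FLAGS` object. A rough stdlib-only analogue of a few of the definitions above (the argparse mapping is my illustration, not part of the original script):

```python
import argparse

# Mirror a handful of run_drfact.py's flag definitions with argparse.
parser = argparse.ArgumentParser()
parser.add_argument("--bert_config_file", default=None,
                    help="Config JSON of the pre-trained BERT model.")
parser.add_argument("--num_hops", type=int, default=2,
                    help="Number of hops in rule template.")
parser.add_argument("--do_train", action="store_true",
                    help="Whether to run training.")
parser.add_argument("--train_batch_size", type=int, default=16,
                    help="Total batch size for training.")

# Parse an explicit argv list, analogous to absl filling in FLAGS.
args = parser.parse_args(["--num_hops", "3", "--do_train"])
print(args.num_hops, args.do_train, args.train_batch_size)  # 3 True 16
```

One difference worth noting: absl flags are registered globally at import time and validated lazily (e.g. via `flags.mark_flag_as_required`), which is why run_drfact.py can declare required parameters like `bert_config_file` with a `None` default.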