上回咱们说到项目的目录以及各个文件的作用,这回我们自顶向下开始分析该项目。项目开始于tf_train_ctc.py文件。
1. 训练的开始
代码如下图所示:
if __name__ == '__main__':
import click
# Use click to parse command line arguments
@click.command()
@click.option('--config', default='neural_network.ini', help='Configuration file name')
@click.option('--name', default=None, help='Model name for logging')
@click.option('--debug', type=bool, default=False,
help='Use debug settings in config file')
# Train RNN model using a given configuration file
def main(config='neural_network.ini', name=None, debug=False):
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s [%(levelname)s] %(name)s: %(message)s')
global logger
logger = logging.getLogger(os.path.basename(__file__))
# create the Tf_train_ctc class
tf_train_ctc = Tf_train_ctc(
config_file=config, model_name=name, debug=debug)
# run the training
tf_train_ctc.run_model()
main()
可以看到,该段代码主要初始化了Tf_train_ctc这个类,然后调用该类的run_model()函数训练模型。下面我们细细分解该类。
2.Tf_train_ctc类的组成
该类所包含的方法如下图所示:
方法名 | 主要作用 |
---|---|
init | 通过neural_network.ini初始化model |
load_configs | 读取neural_network.ini中的配置信息 |
set_up_directories | 设置session/summary数据存放目录 |
set_up_model | 获得训练所需数据以及配置 |
run_model | 执行模型训练相关操作 |
setup_network_and_graph | 配置网络的输入输出 |
load_placeholder_into_network | 构建SimpleLSTM/BiRNN网络 |
setup_loss_function | 设置网络的ctc_loss函数 |
setup_optimizer | 使用AdamOptimizer优化 |
setup_decoder | 配置网络输出的decoder |
setup_summary_statistics | 使用tensorboard读取训练时的相关信息 |
run_training_epochs | 每次训练迭代调用的函数 |
run_validation_step | 验证网络调用的函数 |
validation_and_checkpoint_check | 设置保存模型参数时间点以及验证模型参数的时间点 |
run_batches | 运行一个batch调用的函数 |
3. Tf_train_ctc类中函数调用关系
由前面可以知道Tf_train_ctc类中最主要的两个函数是init 和run_model,这两个函数通过调用该类中其他的函数分别实现模型的初始化和训练,其调用关系如下所示:
其中configuration包含以下函数:
其中run_training_epochs调用如下函数:
本回简要介绍了Tf_train_ctc内部的函数调用关系图,给出了该训练模型的骨架。下回我们结合具体代码介绍模型是如何初始化的。