代码追踪
1、从finetune训练的脚本文件:./script/run_train.py
2、在准备参数及指定log目录后,执行训练文件:./train.py
3、在其中分别有大致几个步骤:
- 前期准备
- 制作reader(数据读取器)
- 准备model
- 制作trainer(训练器)
- 训练并验证
配置文件
路径:./config/roberta_skep_large_en.SST-2.cls.json
- dataset_reader
- model
- trainer
配置信息被读入为 params 类,并将 “None” 字符串替换为 None
主要文件、继承关系
reader
1、field_reader:base_field_reader --> (根据每个field指定的type) 如text:ernie_text_field_reader
2、dataset_reader:base_dataset_reader --> ernie_onesentclassifition_dataset_reader
tokenizer
tokenizer --> tokenizer_wp
embedding
ERNIE
model
model --> roberta_classifition --> roberta_one_sentclassifition_en
trainer
base_trainer -> glue_task_trainer
vocab字典
roberta_en.vocab.bpe
0. senta核心代码类注册
代码追踪
注册代码:./senta/common\register.py
被注册的相关核心代码类位置:./senta/,如:
- ./senta/data:数据读入的相关,如:各类 field_reader、data_set_reader(都由BasicDataSetReader继承变形而来)
- ./senta/model:各类 model
- ./senta/training:各类 trainer
- ……
这些类会通过module注册到 RegisterSet 类内,方便后期在制作数据读取器、准备模型、制作训练器时取用
1、准备reader信息,用于读取数据
代码追踪
./train.py 中完成了 RegisterSet 的构建及 import之后,开始读取配置信息
其中根据配置信息,前往DataSet类进行构建:./senta/data/data_set.py
其中制作数据集的 data_set_reader,其中包括:
- 每个field的 field_reader
- 读取的配置,如:epoch、batch_size等信息
data_set.py里的详细步骤
-
读取 dataset_reader 里对 train/test/dev/predict 4种情况的data_set_reader 配置信息
-
制作每种情况的 data_set_reader:
2.1. 将该reader的每个field配置读入为 Field 类
2.2. 根据 Field 的 type 在 RegisterSet 里的 field_reader 里找对应的读取器
2.3. 将该reader的epoch等配置信息读入为 ReaderConfig 类
2.4. 获取该reader的命名配置字段
2.5. 根据命名、各field_reader、readerConfig,创建该情况下的 data_set_reader -
如果配置信息里没有predict的配置,则用train的reader配置生成predict_reader来代替
至此完成数据集的读取器准备
使用:
senta.data.data_set_reader.ernie_onesentclassification_dataset_reader_en.OneSentClassifyReaderEn
2、model准备
根据配置信息的
- trainers_num 训练器个数
- num_train_examples 训练集的文本数
- batch_size_train
- epoch_train
计算 max_train_steps = epoch_train * num_train_examples // batch_size_train // trainers_num # 取整除
根据参数,根据type,在 RegisterSet 里寻找对应 model 并进行创建
其中,在统计训练集的文本数 num_train_examples 时,进行训练集文件的读取:./data/en/finetune/SST-2/train\train.tsv,并将读取结果存放在 data_set_reader 里
使用:senta.models.roberta_one_sent_classification_en.RobertaOneSentClassificationEn
3、制作trainer
根据前期制作好的 dataset_reader、model,以及根据获取的 num_train_examples 训练文件个数信息进行配置
根据type,在 RegisterSet 里寻找对应 trainer 并进行创建
步骤:
根据reader信息,分别为 train/test/dev/predict 创建 data_set_reader
使用:senta.training.glue_task_trainer.GlueTaskTrainer
至此完成 reader、model、trainer 的准备