项目重要文件的目录树如下:
spert
│ args.py # 各种参数的设置
│ config_reader.py # 读取并处理config文件
│ spert.py # 程序入口
│ __init__.py # 空文件,用于构成package
├─bert-base-chinese
│ bert_config.json
│ config.json
│ pytorch_model.bin
│ vocab.txt
├─configs # config文件
│ example_eval.conf
│ example_predict.conf
│ example_train.conf
├─data # 数据集(已处理好格式)
│ ├─datasets
│ │ ├─ade
│ │ ├─conll04
│ │ └─scierc
│ ├─log
│ └─save
└─spert
entities.py
evaluator.py
input_reader.py # 读取输入,定义BaseInputReader和JsonInputReader类
loss.py # 定义SpERTLoss类,类函数compute()计算模型loss
models.py # 定义SpERT类(继承transformers库中的BertPreTrainedModel类)
opt.py # optional
prediction.py
sampling.py # 生成正负样本
spert_trainer.py # 定义SpERTTrainer类(继承BaseTrainer类)
trainer.py # 定义BaseTrainer类
util.py # 各种小函数
__init__.py
1.程序入口
训练模型使用如下命令:
python ./spert.py train --config configs/example_train.conf
故程序的入口是spert.py
:
if __name__ == '__main__':
arg_parser = argparse.ArgumentParser(add_help=False)
arg_