论文链接:Hierarchy-Aware Global Model for Hierarchical Text Classification
github代码链接:HiAGM
HiAGM模型源码测试
1. 模型简介
详细内容见阅读笔记
2. 代码结构剖析
2.1. baseDIR
-HiAGM
--config # json配置文件路径(下列文件都是示例文件)
--data # 数据及预处理文件示例
--data_modules # 数据处理脚本文件夹
--helper # 训练辅助脚本文件夹
--models # 模型主体架构文件夹
--train_modules # 训练过程脚本文件夹
--LICENSE
--README.md
--evaluate.py # 训练过程评价脚本(评价参数的计算)
--train.py # 模型训练执行脚本
2.2. configDIR
json配置文件示例。
-HiAGM
--config
# gcn结构编码器、HiAGM-TP框架、rcv1数据集、cpu运行
---gcn-rcv1-v2-cpu.json
# gcn结构编码器、HiAGM-TP框架、rcv1数据集、gpu运行
---gcn-rcv1-v2.json
# gcn结构编码器、HiAGM-LA框架、rcv1数据集、gpu运行
---gcnla-rcv1-v2.json
# gcn结构编码器、Origin框架、rcv1数据集、gpu运行
---rcv1-v2.json
# TreeLSTM结构编码器、HiAGM-TP框架、rcv1数据集、gpu运行
---tree-rcv1-v2.json
# TreeLSTM结构编码器、HiAGM-LA框架、rcv1数据集、gpu运行
---treela-rcv1-v2.json
2.3. dataDIR
rcv1示例数据,及rcv1,ny,wos数据集的预处理文件。
-HiAGM
--data
# nyt数据预处理脚本
---preprocess_nyt.py
# wos数据预处理脚本
---preprocess_wos.py
# 纽约时报数据文件
---idnewnyt_test.json
---idnewnyt_train.json
---idnewnyt_val.json
---nyt.taxonomy
---nyt_label.vocab
# rcv1数据文件
---rcv1.taxonomy
---rcv1_overall_corpus_train_prob.json
---rcv1_prob.json
---rcv1_test.json
---rcv1_train.json
---rcv1_val.json
---sample_rcv1.taxonomy
数据文件包含以下5个:
# 训练集
dataset_train.json
# 测试集
dataset_test.json
# 验证集
dataset_val.json
# label层级结构
dataset.taxonomy
# label先验概率
dataset_prob.json
2.4. data_modulesDIR
基于准备好的数据,做预处理,生成词表,模型数据加载模块等处理脚本。
-HiAGM
--data_modules
# 校对数据,用于data_loader.py中
---collator.py
# 数据加载文件,加载数据集数据,用于模型训练
---data_loader.py
# 读取数据,用于data_loader.py中
---dataset.py
# 数据预处理,包括去停用词,标点清洗等
---preprocess.py
# 构建vocab文件夹中的label.dict和word.dict文件
---vocab.py
2.5. helperDIR
辅助脚本文件,如日志管理、checkpoint的加载和保存等函数。
-HiAGM
--helper
# 加载json配置文件内容
---configure.py
# 用于计算data文件夹中的dataset.taxonomy和dataset_prob.json文件
---hierarchy_tree_statistic.py
# 日志管理
---logger.py
# checkpoint的加载和保存等函数
---utils.py
2.6. modelsDIR
模型架构文件夹,包含结构编码器、嵌入层、多标签注意力等模型结构脚本。
-HiAGM
--models
# 结构编码器相关脚本文件夹(包括GCNN、结构编码器、tree结构生成、加权treeLSTM)
---structure_model
# 嵌入层
---embedding_layer.py
# 模型主体文件
---model.py
# 多标签注意力
---multi_label_attention.py
# 原始模型主体
---origin.py
# 文本编码器
---text_encoder.py
# 文本特征传播
---text_feature_propagation.py
2.7. train_modulesDIR
-HiAGM
--train_modules
# loss函数
---criterions.py
# 评价指标计算
---evaluation_metrics.py
# 训练器
---trainer.py
3. 评价指标
3.1. 基础概念
预测值=1 | 预测值=0 | |
---|---|---|
真实值=1 | TP | FN |
真实值=0 | FP | TN |
- TP = True Postive = 真阳性;
- FP = False Positive = 假阳性;
- FN = False Negative = 假阴性;
- TN = True Negative = 真阴性;
- 精度
p r e c i s i o n , P P V , p o s i t i v e p r e d i c t i v e v a l u e = T P / ( T P + F P ) precision, PPV, positive predictive value = TP / (TP + FP) precision,PPV,positivepredictivevalue=TP/(TP+FP)
- 召回或者敏感度,真阳性率
r e c a l l , s e n s i t i v i t y , T P R , T r u e P o s i t i v e R a t e = T P / ( T P + F N ) recall, sensitivity, TPR, True Positive Rate = TP / (TP + FN) recall,sensitivity,TPR,TruePositiveRate=TP/(TP+FN)
- 特异度,或者真阴性率
s p e c i f i c i t y , T N R , T r u e N e g a t i v e R a t e = T N / ( T N + F P ) specificity, TNR, True Negative Rate = TN / (TN + FP) specificity,TNR,TrueNegativeRate=TN/(TN+FP)
- F1-值
F 1 _ s c o r e = 2 ∗ p r e c i s i o n ∗ r e c a l l / ( p r e c i s i o n + r e c a l l ) = 2 ∗ T P / ( 2 ∗ T P + F P + F N ) F1\_score = 2 * precision * recall / (precision + recall) =2*TP / (2*TP+FP+FN) F1_score=2∗precision∗recall/(precision+recall)=2∗TP/(2∗TP+FP+FN)
3.2. Micro-F1
不需要区分知识点,直接使用总体样本的准确率和召回率计算F1-score
。即
m
i
c
r
o
_
f
1
=
2
∗
p
r
e
c
i
s
i
o
n
∗
r
e
c
a
l
l
/
(
p
r
e
c
i
s
i
o
n
+
r
e
c
a
l
l
)
micro\_f1 = 2 * precision * recall / (precision + recall)
micro_f1=2∗precision∗recall/(precision+recall)
precision_micro = float(right_total) / predict_total if predict_total > 0 else 0.0
recall_micro = float(right_total) / gold_total
micro_f1 = 2 * precision_micro * recall_micro / (precision_micro + recall_micro) if (precision_micro + recall_micro) > 0 else 0.0
3.3. Macro-F1
macro-f1需要先计算出每一个类别的准确率和召回率及其F1-score
,然后通过求均值得到在整个样本上的F1-score
。
precision_macro = sum([v for _, v in precision_dict.items()]) / len(list(precision_dict.keys()))
recall_macro = sum([v for _, v in recall_dict.items()]) / len(list(precision_dict.keys()))
macro_f1 = sum([v for _, v in fscore_dict.items()]) / len(list(fscore_dict.keys()))
4. 原始数据集测试
4.1. RCV1(样本集)
可跑通样本集
配置文件:gcn-rcv1-v2.json
训练信息:
2021/08/26 09:58:42 - INFO : Building Vocabulary....
2021/08/26 09:58:42 - INFO : Loading Vocabulary from Cached Dictionary...
2021/08/26 09:58:42 - INFO : Vocabulary of token 50002
2021/08/26 09:58:42 - INFO : Vocabulary of label 40
2021/08/26 09:58:43 - INFO : Loading 300-dimension token embedding from pretrained file: USERPATH/HiAGM/glove.6B/glove.6B.300d.txt
2021/08/26 09:59:00 - INFO : Total vocab size of token is 50002.
2021/08/26 09:59:00 - INFO : Pretrained vocab embedding has 49965 / 50002
结果:
2021/08/26 10:02:43 - INFO : Epoch 76 Time Cost 2.3963842391967773 secs.
2021/08/26 10:02:45 - INFO : TRAIN performance at epoch 77 --- Precision: 1.000000, Recall: 0.128205, Micro-F1: 0.227273, Macro-F1: 0.034722, Loss: 0.345776.
2021/08/26 10:02:46 - INFO : DEV performance at epoch 77 --- Precision: 1.000000, Recall: 0.037037, Micro-F1: 0.071429, Macro-F1: 0.016667, Loss: 0.399174.
2021/08/26 10:02:50 - INFO : TEST performance at epoch 27 --- Precision: 1.000000, Recall: 0.178571, Micro-F1: 0.303030, Macro-F1: 0.045000, Loss: 0.425028.
4.1.1. 原始数据集(暂未训练)
4.2. NYT(付费数据集,暂未获取)
4.3. WOS
可跑通
配置文件:gcn-wos-v2-cpu.json
训练信息:
2021/08/23 13:26:43 - INFO : Building Vocabulary....
2021/08/23 13:26:43 - INFO : Loading Vocabulary from Cached Dictionary...
2021/08/23 13:26:43 - INFO : Vocabulary of token 50002
2021/08/23 13:26:43 - INFO : Vocabulary of label 141
2021/08/23 13:26:45 - INFO : Loading 300-dimension token embedding from pretrained file: USERPATH/HiAGM/glove.6B/glove.6B.300d.txt
2021/08/23 13:26:56 - INFO : Total vocab size of token is 50002.
2021/08/23 13:26:56 - INFO : Pretrained vocab embedding has 49965 / 50002
结果:
2021/08/28 07:31:06 - INFO : Epoch 198 Time Cost 2063.6894171237946 secs.
2021/08/28 08:05:01 - INFO : TRAIN performance at epoch 199 --- Precision: 0.997131, Recall: 0.999917, Micro-F1: 0.998522, Macro-F1: 0.989585, Loss: 0.001030.
2021/08/28 08:05:36 - INFO : DEV performance at epoch 199 --- Precision: 0.865440, Recall: 0.822559, Micro-F1: 0.843455, Macro-F1: 0.764416, Loss: 0.018950.
2021/08/28 08:06:24 - INFO : TEST performance at epoch 149 --- Precision: 0.856174, Recall: 0.819091, Micro-F1: 0.837222, Macro-F1: 0.760100, Loss: 0.018396.
5. 中文数据集测试
5.1. 项目数据集构建
5.1.1. embedding文件
没有找到现成的300维中文embedding,暂时先使用
sgns.merge.word
文件。
5.1.2. 数据文件
数据预处理脚本文件
preprocess_ty.py
。
文件列表:
data/ty_total.json
data/ty_train.json
data/ty_test.json
data/ty_val.json
data/ty.taxnomy
.json
文件内容示例:
{"label": ["M131", "M13", "MCAT"], "token": ["The", "German", "central", "bank", "announced", "largerthanexpected", "cut", "main", "money", "market", "interest", "rate", "Thursday", "boosting", "US", "dollar", "triggering", "interest", "rate", "cuts", "European"]}
.taxnomy
文件内容示例:
Root CCAT ECAT GCAT MCAT
CCAT C12 C13 C15 C18 C22 C24 C31 C33 C41
C15 C151 C152
...
5.1.3. 辅助文件
5.1.3.1. data/ty_prob.json
数据处理脚本文件
get_prior_prob.py
。
用于获取先验概率。
文件内容示例:
{"Root": {"CCAT": 0.37500000000000006, "ECAT": 0.16666666666666669, "GCAT": 0.29166666666666674, "MCAT": 0.16666666666666669},
"CCAT": {"C12": 0.08333333333333333, "C13": 0.16666666666666666, "C15": 0.3333333333333333, "C18": 0.08333333333333333, "C22": 0.08333333333333333, "C24": 0.16666666666666666, "C31": 0.0, "C33": 0.08333333333333333, "C41": 0.0},
...
}
5.1.3.2. ty_vocab/word.dict
文件内容示例:
# 词\t词频
said 104
The 65
percent 50
company 30
pct 30
year 29
government 27
5.1.3.3. ty_vocab/label.dict
文件内容示例:
# 标签\t标签下级节点个数
E12 1
ECAT 5
G15 2
G154 1
GCAT 8
GPOL 3
M13 2
5.2. config文件配置
config/gcn-ty-v2.json
文件内容:
{
"data": {
"dataset": "ty",
"data_dir": "data",
"train_file": "ty_train.json",
"val_file": "ty_val.json",
"test_file": "ty_test.json",
"prob_json": "ty_prob.json",
"hierarchy": "ty.taxonomy"
},
"vocabulary": {
"dir": "ty_vocab",
"vocab_dict": "word.dict",
"max_token_vocab": 60000,
"label_dict": "label.dict"
},
"embedding": {
"token": {
"dimension": 300,
"type": "pretrain",
"pretrained_file": "USERPATH/HiAGM/sgns/sgns.merge.word",
"dropout": 0.5,
"init_type": "uniform"
},
"label": {
"dimension": 300,
"type": "random",
"dropout": 0.5,
"init_type": "kaiming_uniform"
}
},
"text_encoder": {
"max_length": 256,
"RNN": {
"bidirectional": true,
"num_layers": 1,
"type": "GRU",
"hidden_dimension": 64,
"dropout": 0.1
},
"CNN": {
"kernel_size": [2, 3, 4],
"num_kernel": 100
},
"topK_max_pooling": 1
},
"structure_encoder": {
"type": "GCN",
"node": {
"type": "text",
"dimension": 300,
"dropout": 0.05
}
},
"model": {
"type": "HiAGM-TP",
"linear_transformation": {
"text_dimension": 300,
"node_dimension": 300,
"dropout": 0.5
},
"classifier": {
"num_layer": 1,
"dropout": 0.5
}
},
"train": {
"optimizer": {
"type": "Adam",
"learning_rate": 0.0001,
"lr_decay": 1.0,
"lr_patience": 5,
"early_stopping": 50
},
"batch_size": 64,
"start_epoch": 0,
"end_epoch": 250,
"loss": {
"classification": "BCEWithLogitsLoss",
"recursive_regularization": {
"flag": true,
"penalty": 0.000001
}
},
"device_setting": {
"device": "cuda",
"visible_device_list": "0",
"num_workers": 10
},
"checkpoint": {
"dir": "ty_hiagm_tp_checkpoint",
"max_number": 10,
"save_best": ["Macro_F1", "Micro_F1"]
}
},
"eval": {
"batch_size": 512,
"threshold": 0.5
},
"test": {
"best_checkpoint": "best_micro_HiAGM-TP",
"batch_size": 512
},
"log": {
"level": "info",
"filename": "gcn-ty-v2.log"
}
}
6. 模型预测
由于模型本身不具备
predict
脚本,项目原因添加了predict
模块。
以快速实现功能为主,脚本/框架比较粗糙,后期再修改。
6.1. 数据集准备
使用
data/xlsx2json.py
生成ty_predict.json
文件。label给空list就行。
文件内容示例:
{"token": ["The", "German", "central", "bank", "announced", "largerthanexpected", "cut", "main", "money", "market", "interest", "rate", "Thursday", "boosting", "US", "dollar", "triggering", "interest", "rate", "cuts", "European"], "label": []}
{"token": ["The", "German", "central", "bank", "announced", "largerthanexpected", "cut", "main", "money", "market", "interest", "rate"], "label": []}
...
6.2. 数据加载文件修改
项目利益相关,不提供本小节代码。
6.2.1. data_modules/dataset.py
文件中ClassificationDataset
类的__init__
函数中self.corpus_files
添加"PRED"。
6.2.2. data_modules/data_loader.py
文件中data_loaders
函数中添加pred_loader
部分。
6.3. config配置文件修改
config/gcn-ty-v2.json
文件中,data
部分添加pred_file
。
文件内容:
{
"data": {
"dataset": "ty",
"data_dir": "data",
"train_file": "ty_train.json",
"val_file": "ty_val.json",
"test_file": "ty_test.json",
"pred_file": "ty_pred.json",
"prob_json": "ty_prob.json",
"hierarchy": "ty.taxonomy"
}
}
6.4. 预测脚本修改
项目利益相关,不提供本小节代码。
6.4.1. predict.py
这里脚本根据train.py
进行的改动,仅快速实现了预测功能,还需要修改优化。
6.4.2. train_modules/trainer.py
文件中Trainer
类的run
函数中添加"PRED"条件,并新增函数pred
。
6.4.3. train_modules/evaluation_metrics.py
文件中evaluate
函数中添加pred_label
,并返回。
7. 预测记录
由于没有进行任何优化和数据清洗处理,只是跑了流程,数据标签有4层,所以当前效果比较差。