HiAGM模型源码测试【原始数据集+中文数据集】

论文链接:Hierarchy-Aware Global Model for Hierarchical Text Classification
github代码链接:HiAGM

1. 模型简介

详细内容见阅读笔记

2. 代码结构剖析

2.1. baseDIR

-HiAGM
--config # json配置文件路径(下列文件都是示例文件)
--data # 数据及预处理文件示例
--data_modules # 数据处理脚本文件夹
--helper # 训练辅助脚本文件夹
--models # 模型主体架构文件夹
--train_modules # 训练过程脚本文件夹
--LICENSE
--README.md
--evaluate.py # 训练过程评价脚本(评价参数的计算)
--train.py # 模型训练执行脚本

2.2. configDIR

json配置文件示例。

-HiAGM
--config
# gcn结构编码器、HiAGM-TP框架、rcv1数据集、cpu运行
---gcn-rcv1-v2-cpu.json 
# gcn结构编码器、HiAGM-TP框架、rcv1数据集、gpu运行
---gcn-rcv1-v2.json
# gcn结构编码器、HiAGM-LA框架、rcv1数据集、gpu运行
---gcnla-rcv1-v2.json
# gcn结构编码器、Origin框架、rcv1数据集、gpu运行
---rcv1-v2.json
# TreeLSTM结构编码器、HiAGM-TP框架、rcv1数据集、gpu运行
---tree-rcv1-v2.json
# TreeLSTM结构编码器、HiAGM-LA框架、rcv1数据集、gpu运行
---treela-rcv1-v2.json

2.3. dataDIR

rcv1示例数据,及rcv1,ny,wos数据集的预处理文件。

-HiAGM
--data
# nyt数据预处理脚本
---preprocess_nyt.py
# wos数据预处理脚本
---preprocess_wos.py
# 纽约时报数据文件
---idnewnyt_test.json
---idnewnyt_train.json
---idnewnyt_val.json
---nyt.taxonomy
---nyt_label.vocab
# rcv1数据文件
---rcv1.taxonomy
---rcv1_overall_corpus_train_prob.json
---rcv1_prob.json
---rcv1_test.json
---rcv1_train.json
---rcv1_val.json
---sample_rcv1.taxonomy

数据文件包含以下5个:

# 训练集
dataset_train.json
# 测试集
dataset_test.json
# 验证集
dataset_val.json
# label层级结构
dataset.taxonomy
# label先验概率
dataset_prob.json

2.4. data_modulesDIR

基于准备好的数据,做预处理,生成词表,模型数据加载模块等处理脚本。

-HiAGM
--data_modules
# 校对数据,用于data_loader.py中
---collator.py
# 数据加载文件,加载数据集数据,用于模型训练
---data_loader.py
# 读取数据,用于data_loader.py中
---dataset.py
# 数据预处理,包括去停用词,标点清洗等
---preprocess.py
# 构建vocab文件夹中的label.dict和word.dict文件
---vocab.py

2.5. helperDIR

辅助脚本文件,如日志管理、checkpoint的加载和保存等函数。

-HiAGM
--helper
# 加载json配置文件内容
---configure.py
# 用于计算data文件夹中的dataset.taxonomy和dataset_prob.json文件
---hierarchy_tree_statistic.py
# 日志管理
---logger.py
# checkpoint的加载和保存等函数
---utils.py

2.6. modelsDIR

模型架构文件夹,包含结构编码器、嵌入层、多标签注意力等模型结构脚本。

-HiAGM
--models
# 结构编码器相关脚本文件夹(包括GCNN、结构编码器、tree结构生成、加权treeLSTM)
---structure_model
# 嵌入层
---embedding_layer.py
# 模型主体文件
---model.py
# 多标签注意力
---multi_label_attention.py
# 原始模型主体
---origin.py
# 文本编码器
---text_encoder.py
# 文本特征传播
---text_feature_propagation.py

2.7. train_modulesDIR

-HiAGM
--train_modules
# loss函数
---criterions.py
# 评价指标计算
---evaluation_metrics.py
# 训练器
---trainer.py

3. 评价指标

3.1. 基础概念

预测值=1预测值=0
真实值=1TPFN
真实值=0FPTN
  • TP = True Postive = 真阳性;
  • FP = False Positive = 假阳性;
  • FN = False Negative = 假阴性;
  • TN = True Negative = 真阴性;
  1. 精度

p r e c i s i o n , P P V , p o s i t i v e p r e d i c t i v e v a l u e = T P / ( T P + F P ) precision, PPV, positive predictive value = TP / (TP + FP) precision,PPV,positivepredictivevalue=TP/(TP+FP)

  1. 召回或者敏感度,真阳性率

r e c a l l , s e n s i t i v i t y , T P R , T r u e P o s i t i v e R a t e = T P / ( T P + F N ) recall, sensitivity, TPR, True Positive Rate = TP / (TP + FN) recall,sensitivity,TPR,TruePositiveRate=TP/(TP+FN)

  1. 特异度,或者真阴性率

s p e c i f i c i t y , T N R , T r u e N e g a t i v e R a t e = T N / ( T N + F P ) specificity, TNR, True Negative Rate = TN / (TN + FP) specificity,TNR,TrueNegativeRate=TN/(TN+FP)

  1. F1-值

F 1 _ s c o r e = 2 ∗ p r e c i s i o n ∗ r e c a l l / ( p r e c i s i o n + r e c a l l ) = 2 ∗ T P / ( 2 ∗ T P + F P + F N ) F1\_score = 2 * precision * recall / (precision + recall) =2*TP / (2*TP+FP+FN) F1_score=2precisionrecall/(precision+recall)=2TP/(2TP+FP+FN)

3.2. Micro-F1

不需要区分知识点,直接使用总体样本的准确率和召回率计算F1-score。即 m i c r o _ f 1 = 2 ∗ p r e c i s i o n ∗ r e c a l l / ( p r e c i s i o n + r e c a l l ) micro\_f1 = 2 * precision * recall / (precision + recall) micro_f1=2precisionrecall/(precision+recall)

precision_micro = float(right_total) / predict_total if predict_total > 0 else 0.0
recall_micro = float(right_total) / gold_total
micro_f1 = 2 * precision_micro * recall_micro / (precision_micro + recall_micro) if (precision_micro + recall_micro) > 0 else 0.0

3.3. Macro-F1

macro-f1需要先计算出每一个类别的准确率和召回率及其F1-score,然后通过求均值得到在整个样本上的F1-score

precision_macro = sum([v for _, v in precision_dict.items()]) / len(list(precision_dict.keys()))
recall_macro = sum([v for _, v in recall_dict.items()]) / len(list(precision_dict.keys()))
macro_f1 = sum([v for _, v in fscore_dict.items()]) / len(list(fscore_dict.keys()))

4. 原始数据集测试

4.1. RCV1(样本集)

可跑通样本集

配置文件:gcn-rcv1-v2.json

训练信息:

2021/08/26 09:58:42 - INFO : Building Vocabulary....
2021/08/26 09:58:42 - INFO : Loading Vocabulary from Cached Dictionary...
2021/08/26 09:58:42 - INFO : Vocabulary of token 50002
2021/08/26 09:58:42 - INFO : Vocabulary of label 40
2021/08/26 09:58:43 - INFO : Loading 300-dimension token embedding from pretrained file: USERPATH/HiAGM/glove.6B/glove.6B.300d.txt
2021/08/26 09:59:00 - INFO : Total vocab size of token is 50002.
2021/08/26 09:59:00 - INFO : Pretrained vocab embedding has 49965 / 50002

结果:

2021/08/26 10:02:43 - INFO : Epoch 76 Time Cost 2.3963842391967773 secs.
2021/08/26 10:02:45 - INFO : TRAIN performance at epoch 77 --- Precision: 1.000000, Recall: 0.128205, Micro-F1: 0.227273, Macro-F1: 0.034722, Loss: 0.345776.

2021/08/26 10:02:46 - INFO : DEV performance at epoch 77 --- Precision: 1.000000, Recall: 0.037037, Micro-F1: 0.071429, Macro-F1: 0.016667, Loss: 0.399174.

2021/08/26 10:02:50 - INFO : TEST performance at epoch 27 --- Precision: 1.000000, Recall: 0.178571, Micro-F1: 0.303030, Macro-F1: 0.045000, Loss: 0.425028.
4.1.1. 原始数据集(暂未训练)

4.2. NYT(付费数据集,暂未获取)

4.3. WOS

可跑通

配置文件:gcn-wos-v2-cpu.json

训练信息:

2021/08/23 13:26:43 - INFO : Building Vocabulary....
2021/08/23 13:26:43 - INFO : Loading Vocabulary from Cached Dictionary...
2021/08/23 13:26:43 - INFO : Vocabulary of token 50002
2021/08/23 13:26:43 - INFO : Vocabulary of label 141
2021/08/23 13:26:45 - INFO : Loading 300-dimension token embedding from pretrained file: USERPATH/HiAGM/glove.6B/glove.6B.300d.txt
2021/08/23 13:26:56 - INFO : Total vocab size of token is 50002.
2021/08/23 13:26:56 - INFO : Pretrained vocab embedding has 49965 / 50002

结果:

2021/08/28 07:31:06 - INFO : Epoch 198 Time Cost 2063.6894171237946 secs.
2021/08/28 08:05:01 - INFO : TRAIN performance at epoch 199 --- Precision: 0.997131, Recall: 0.999917, Micro-F1: 0.998522, Macro-F1: 0.989585, Loss: 0.001030.

2021/08/28 08:05:36 - INFO : DEV performance at epoch 199 --- Precision: 0.865440, Recall: 0.822559, Micro-F1: 0.843455, Macro-F1: 0.764416, Loss: 0.018950.

2021/08/28 08:06:24 - INFO : TEST performance at epoch 149 --- Precision: 0.856174, Recall: 0.819091, Micro-F1: 0.837222, Macro-F1: 0.760100, Loss: 0.018396.

5. 中文数据集测试

5.1. 项目数据集构建

5.1.1. embedding文件

没有找到现成的300维中文embedding,暂时先使用sgns.merge.word文件。

5.1.2. 数据文件

数据预处理脚本文件preprocess_ty.py

文件列表:

data/ty_total.json
data/ty_train.json
data/ty_test.json
data/ty_val.json
data/ty.taxnomy

.json文件内容示例:

{"label": ["M131", "M13", "MCAT"], "token": ["The", "German", "central", "bank", "announced", "largerthanexpected", "cut", "main", "money", "market", "interest", "rate", "Thursday", "boosting", "US", "dollar", "triggering", "interest", "rate", "cuts", "European"]}

.taxnomy文件内容示例:

Root	CCAT	ECAT	GCAT	MCAT
CCAT	C12	C13	C15	C18	C22	C24	C31	C33	C41
C15	C151	C152
...
5.1.3. 辅助文件
5.1.3.1. data/ty_prob.json

数据处理脚本文件get_prior_prob.py
用于获取先验概率。

文件内容示例:

{"Root": {"CCAT": 0.37500000000000006, "ECAT": 0.16666666666666669, "GCAT": 0.29166666666666674, "MCAT": 0.16666666666666669}, 
"CCAT": {"C12": 0.08333333333333333, "C13": 0.16666666666666666, "C15": 0.3333333333333333, "C18": 0.08333333333333333, "C22": 0.08333333333333333, "C24": 0.16666666666666666, "C31": 0.0, "C33": 0.08333333333333333, "C41": 0.0},
...
}
5.1.3.2. ty_vocab/word.dict

文件内容示例:

# 词\t词频
said	104
The	65
percent	50
company	30
pct	30
year	29
government	27
5.1.3.3. ty_vocab/label.dict

文件内容示例:

# 标签\t标签下级节点个数
E12	1
ECAT	5
G15	2
G154	1
GCAT	8
GPOL	3
M13	2

5.2. config文件配置

config/gcn-ty-v2.json

文件内容:

{
  "data": {
    "dataset": "ty",
    "data_dir": "data",
    "train_file": "ty_train.json",
    "val_file": "ty_val.json",
    "test_file": "ty_test.json",
    "prob_json": "ty_prob.json",
    "hierarchy": "ty.taxonomy"
  },
  "vocabulary": {
    "dir": "ty_vocab",
    "vocab_dict": "word.dict",
    "max_token_vocab": 60000,
    "label_dict": "label.dict"
  },
  "embedding": {
    "token": {
      "dimension": 300,
      "type": "pretrain",
      "pretrained_file": "USERPATH/HiAGM/sgns/sgns.merge.word",
      "dropout": 0.5,
      "init_type": "uniform"
    },
    "label": {
      "dimension": 300,
      "type": "random",
      "dropout": 0.5,
      "init_type": "kaiming_uniform"
    }
  },
  "text_encoder": {
    "max_length": 256,
    "RNN": {
      "bidirectional": true,
      "num_layers": 1,
      "type": "GRU",
      "hidden_dimension": 64,
      "dropout": 0.1
    },
    "CNN": {
      "kernel_size": [2, 3, 4],
      "num_kernel": 100
    },
    "topK_max_pooling": 1
  },
  "structure_encoder": {
    "type": "GCN",
    "node": {
      "type": "text",
      "dimension": 300,
      "dropout": 0.05
    }
  },
  "model": {
    "type": "HiAGM-TP",
    "linear_transformation": {
      "text_dimension": 300,
      "node_dimension": 300,
      "dropout": 0.5
    },
    "classifier": {
      "num_layer": 1,
      "dropout": 0.5
    }
  },
  "train": {
    "optimizer": {
      "type": "Adam",
      "learning_rate": 0.0001,
      "lr_decay": 1.0,
      "lr_patience": 5,
      "early_stopping": 50
    },
    "batch_size": 64,
    "start_epoch": 0,
    "end_epoch": 250,
    "loss": {
      "classification": "BCEWithLogitsLoss",
      "recursive_regularization": {
        "flag": true,
        "penalty": 0.000001
      }
    },
    "device_setting": {
      "device": "cuda",
      "visible_device_list": "0",
      "num_workers": 10
    },
    "checkpoint": {
      "dir": "ty_hiagm_tp_checkpoint",
      "max_number": 10,
      "save_best": ["Macro_F1", "Micro_F1"]
    }
  },
  "eval": {
    "batch_size": 512,
    "threshold": 0.5
  },
  "test": {
    "best_checkpoint": "best_micro_HiAGM-TP",
    "batch_size": 512
  },
  "log": {
    "level": "info",
    "filename": "gcn-ty-v2.log"
  }
}

6. 模型预测

由于模型本身不具备predict脚本,项目原因添加了predict模块。
以快速实现功能为主,脚本/框架比较粗糙,后期再修改。

6.1. 数据集准备

使用data/xlsx2json.py生成ty_predict.json文件。label给空list就行。

文件内容示例:

{"token": ["The", "German", "central", "bank", "announced", "largerthanexpected", "cut", "main", "money", "market", "interest", "rate", "Thursday", "boosting", "US", "dollar", "triggering", "interest", "rate", "cuts", "European"], "label": []}
{"token": ["The", "German", "central", "bank", "announced", "largerthanexpected", "cut", "main", "money", "market", "interest", "rate"], "label": []}
...

6.2. 数据加载文件修改

项目利益相关,不提供本小节代码。

6.2.1. data_modules/dataset.py

文件中ClassificationDataset类的__init__函数中self.corpus_files添加"PRED"。

6.2.2. data_modules/data_loader.py

文件中data_loaders函数中添加pred_loader部分。

6.3. config配置文件修改

config/gcn-ty-v2.json文件中,data部分添加pred_file

文件内容:

{
  "data": {
    "dataset": "ty",
    "data_dir": "data",
    "train_file": "ty_train.json",
    "val_file": "ty_val.json",
    "test_file": "ty_test.json",
    "pred_file": "ty_pred.json",
    "prob_json": "ty_prob.json",
    "hierarchy": "ty.taxonomy"
  }
}

6.4. 预测脚本修改

项目利益相关,不提供本小节代码。

6.4.1. predict.py

这里脚本根据train.py进行的改动,仅快速实现了预测功能,还需要修改优化。

6.4.2. train_modules/trainer.py

文件中Trainer类的run函数中添加"PRED"条件,并新增函数pred

6.4.3. train_modules/evaluation_metrics.py

文件中evaluate函数中添加pred_label,并返回。

7. 预测记录

由于没有进行任何优化和数据清洗处理,只是跑了流程,数据标签有4层,所以当前效果比较差。

  • 4
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 12
    评论
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值