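1. Paper Overview
Title: "SIMTEG: A FRUSTRATINGLY SIMPLE APPROACH IMPROVES TEXTUAL GRAPH LEARNING"
The idea of the paper is very simple. It trains models in two stages:
(1) Fine-tuning the language model (LM): the LM is fine-tuned on the text of the textual graph and the corresponding labels. This stage is run separately for each of the two tasks (node classification and link prediction).
(2) Training the GNN: the features X generated by the fine-tuned LM are used for the next step, GNN training, which is again run separately for each of the two tasks (node classification and link prediction).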
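Core idea: read the overview figure above together with the pseudocode below:
(1) f_lm: the language model combined with parameter-efficient fine-tuning; f_mlp: a fully connected model; f_gnn: the GNN model;
(2) Input: the textual graph (adjacency matrix A, token ids T), the attention mask att_msk (M), and the labels for each task (Y).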
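Procedure:
Task 1, node classification:
(1) Feed T and M into f_lm to compute X, pass X through the fully connected layer to obtain the predicted label scores (logits), compute the cross-entropy loss between the logits and the true labels Y, then backpropagate and update the weights;
(2) Use the fine-tuned LM to compute the vector representation X of T;
(3) Train the GNN with this X, the given A, and the labels Y: each layer aggregates information over A and the previous layer's X to produce the next layer's representation X (the standard GNN training procedure). The last layer's X is fed into a fully connected layer to obtain the predicted label scores (logits); then compute the cross-entropy loss, backpropagate, and update the weights until the model reaches its optimum.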
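Step (3) can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the paper's implementation: the weight-free mean aggregator, the toy graph, the classifier weights, and the names mean_aggregate, linear, and cross_entropy are all made up here, and the gradient update is omitted.

```python
import math

def mean_aggregate(adj, feats):
    """One aggregation step: each node averages its own features with its neighbors'.

    A weight-free mean aggregator, assumed here purely for illustration.
    """
    out = []
    for i, row in enumerate(adj):
        # Neighbors of node i plus node i itself (self-loop), without duplicates.
        nodes = sorted({j for j, a in enumerate(row) if a} | {i})
        dim = len(feats[i])
        out.append([sum(feats[j][d] for j in nodes) / len(nodes) for d in range(dim)])
    return out

def linear(x, weight, bias):
    """Fully connected layer: one logit per class."""
    return [sum(xd * wd for xd, wd in zip(x, w)) + b for w, b in zip(weight, bias)]

def cross_entropy(logits, label):
    """Softmax cross-entropy for one node, computed via log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[label]

# Toy textual graph: path 0 - 1 - 2, with 2-dimensional LM features X.
A = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
Y = [0, 1, 1]

H = mean_aggregate(A, X)          # next-layer representation
W = [[1.0, -1.0], [-1.0, 1.0]]    # hypothetical classifier weights (2 classes)
b = [0.0, 0.0]
# Mean cross-entropy over all nodes; a real run would backpropagate this loss.
loss = sum(cross_entropy(linear(h, W, b), y) for h, y in zip(H, Y)) / len(Y)
```

In a real run the aggregation would have learned weights and several layers, and the loss would drive an optimizer step; the sketch only shows how A, X, and Y enter the computation.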
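Task 2, link prediction:
Only the loss function differs; the rest of the training is the same. The loss is BCEWithLogitsLoss, a loss function commonly used for binary classification in PyTorch. It combines the Sigmoid function and the binary cross-entropy (BCE) loss in a single operation, which makes the computation more numerically stable and efficient.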
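To see why fusing Sigmoid and BCE helps, here is a minimal plain-Python sketch of the numerically stable form for a single logit x and target y, namely max(x, 0) - x*y + log(1 + exp(-|x|)). The function names are chosen here for illustration; this is not PyTorch's source code.

```python
import math

def bce_with_logits(logit, target):
    """Stable BCE on a raw logit: max(x, 0) - x*y + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

def naive_bce(logit, target):
    """Sigmoid followed by BCE; breaks down for logits of large magnitude."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))

# Both forms agree for moderate logits.
moderate_gap = abs(bce_with_logits(2.0, 1.0) - naive_bce(2.0, 1.0))

# For a very confident wrong prediction the stable form still returns a
# finite loss, whereas naive_bce(1000.0, 0.0) would evaluate log(0).
extreme_loss = bce_with_logits(1000.0, 0.0)
```

For link prediction, the logit would be the score the model assigns to a candidate edge and the target would be 1 for a true edge and 0 for a negative sample.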
2. Code Reproduction and Hands-On Practice
Environment
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
conda install pyg -c pyg
conda install -c dglteam/label/cu118 dgl # for RevGAT
pip install transformers
pip install optuna # for hp search
pip install deepspeed # recommended if you want to fine-tune the LM yourself
2.1 Get Started With An Example
For all results reported in our paper on OGBN-Arxiv, OGBN-Products, and OGBL-citation2-2.7M, we place the training scripts at ./scripts. Here is an example of reproducing the results of e5-large (LM) + GraphSAGE (GNN).
2.1.1 We should first fine-tune the language model on a specific dataset:
dataset=ogbn-arxiv
model_type=e5-large
suffix=main
# it takes half an hour with 4 A100 (40G)
bash scripts/train.sh --model_type $model_type --dataset $dataset --suffix $suffix \
--pretrained_repo sentence-transformers/e5-large \
--lr 5e-5 \
--weight_decay 1e-5 \
--batch_size 20 \
--eval_batch_size 200 \
--accum_interval 5 \
--label_smoothing 0.3 \
--epochs 10 \
--warmup_ratio 0.15 \
--lr_scheduler_type linear \
--use_peft \
--peft_r 4 \
--peft_lora_alpha 8 \
--peft_lora_dropout 0.3 \
--header_dropout_prob 0.6 \
--deepspeed ds_config.json # optional, we use stage 2 of deepspeed
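The flags --lr 5e-5, --warmup_ratio 0.15, and --lr_scheduler_type linear suggest a learning rate that rises linearly during the first 15% of steps and then decays linearly to zero. A minimal sketch of such a schedule follows. How train.sh actually builds its scheduler is not shown in this post, so the behavior, the function name linear_schedule_lr, and the step count of 100 are assumptions.

```python
def linear_schedule_lr(step, total_steps, base_lr=5e-5, warmup_ratio=0.15):
    """Learning rate at `step` under linear warmup followed by linear decay."""
    warmup_steps = round(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Warmup phase: ramp from 0 up to base_lr.
        return base_lr * step / max(1, warmup_steps)
    # Decay phase: ramp from base_lr down to 0 at the final step.
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)

# Hypothetical run of 100 optimizer steps: warmup covers steps 0..14.
schedule = [linear_schedule_lr(s, 100) for s in range(101)]
peak_step = max(range(101), key=schedule.__getitem__)
```

Note that with --accum_interval 5 the optimizer steps once every 5 batches, so the step counted here would be the optimizer step, not the batch index.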
All output will be saved at ./out/${dataset}/${model_type}/${suffix}. Specifically, the generated embeddings are at ./out/${dataset}/${model_type}/${suffix}/cached_embs/x_embs.pt; here that is ./out/ogbn-arxiv/e5-large/main/cached_embs/x_embs.pt.
Or download the x_embs.pt from our huggingface repo.
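Because the GNN stage must point at exactly this file, it can help to build the path from the same three variables. A small sketch based on the layout stated above; the helper name cached_embs_path is hypothetical and is not part of the repository.

```python
from pathlib import PurePosixPath

def cached_embs_path(dataset, model_type, suffix, root="out"):
    """Build <root>/<dataset>/<model_type>/<suffix>/cached_embs/x_embs.pt."""
    return PurePosixPath(root) / dataset / model_type / suffix / "cached_embs" / "x_embs.pt"

# Same values as in the fine-tuning command above.
embs_path = str(cached_embs_path("ogbn-arxiv", "e5-large", "main"))
```

The resulting string is the value to pass as the embedding path in the next step.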
2.1.2 Then we train a GraphSAGE on top of the generated embeddings:
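dataset=ogbn-arxiv
model_type=GraphSAGE # the GNN to train; model_type was e5-large in the previous step
lm_model_type=e5-large
suffix=main_X_${lm_model_type}
bert_x_dir=out/ogbn-arxiv/e5-large/main/cached_embs/x_embs.pt # should be consistent with the output path of the previous step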
bash scripts/single_gpu_train.sh --model_type $model_type --dataset $dataset --suffix $suffix \
--n_exps 10 \
--single_gpu 0 \
--lm_type $lm_model_type \
--gnn_batch_size 10000 \
--gnn_eval_batch_size 10000 \
--gnn_epochs 100 \
--gnn_dropout 0.4 \
--gnn_label_smoothing 0.4 \
--gnn_lr 0.01 \
--gnn_num_layers 2 \
--gnn_weight_decay 4e-6 \
--gnn_eval_interval 1 \
--use_bert_x \
--bert_x_dir $bert_x_dir
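Both stages set a label-smoothing value (0.3 for the LM, 0.4 for the GNN). As a rough illustration of what such a value does, here is a plain-Python sketch of label-smoothed cross-entropy, assuming the common convention in which the one-hot target is mixed with a uniform distribution over the K classes. Whether the training scripts use exactly this convention is an assumption, and the function names are made up here.

```python
import math

def log_softmax(logits):
    """Log-probabilities via log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]

def smoothed_cross_entropy(logits, label, smoothing=0.4):
    """Cross-entropy against the target (1 - s) * one_hot + s / K."""
    k = len(logits)
    log_p = log_softmax(logits)
    target = [smoothing / k + (1.0 - smoothing) * (c == label) for c in range(k)]
    return -sum(q * lp for q, lp in zip(target, log_p))

logits = [2.0, 0.0, 0.0]
plain = smoothed_cross_entropy(logits, 0, smoothing=0.0)     # ordinary cross-entropy
smoothed = smoothed_cross_entropy(logits, 0, smoothing=0.4)  # penalizes overconfidence
```

With smoothing, a correct and confident prediction still pays a small loss for the probability mass it withholds from the other classes, which acts as a regularizer.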
2.2 Do Ensembling
Following the above instructions, we can train GNNs on multiple embeddings. To reproduce the results on OGBN-Arxiv, one should train a GNN on the original text and on TAPE with various LMs (e5-large and all-roberta-large-v1).
bash scripts/ogbn-arxiv/e5-large/main.sh
bash scripts/ogbn-arxiv-tape/e5-large/main.sh
bash scripts/ogbn-arxiv/roberta-large/main.sh
bash scripts/ogbn-arxiv-tape/roberta-large/main.sh
bash scripts/ogbn-arxiv-tape/revgat/main.sh # contains all training scripts
logits1=out/ogbn-arxiv/revgat/ensemble_X_e5-large/cached_embs
logits2=out/ogbn-arxiv/revgat/ensemble_X_all-roberta-large-v1/cached_embs
logits3=out/ogbn-arxiv-tape/revgat/ensemble_X_e5-large/cached_embs
logits4=out/ogbn-arxiv-tape/revgat/ensemble_X_all-roberta-large-v1/cached_embs
logits5=out/ogbn-arxiv/revgat/ensemble_preds/cached_embs
python compute_ensemble.py \
--list_logits "${logits1} ${logits2} ${logits3} ${logits4} ${logits5}" \
--weights 2 2 1 1 1 \
--start_seed 1
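The --weights 2 2 1 1 1 argument gives the first two sets of logits twice the weight of the others. The exact combination rule inside compute_ensemble.py is not shown in this post; the sketch below assumes a weighted average of per-class scores followed by an argmax, and the function name weighted_ensemble and the toy scores are hypothetical.

```python
def weighted_ensemble(list_logits, weights):
    """Weighted average of per-model class scores, then argmax per node.

    list_logits: one [n_nodes][n_classes] score table per model.
    weights: one non-negative weight per model.
    """
    n_nodes = len(list_logits[0])
    n_classes = len(list_logits[0][0])
    total = float(sum(weights))
    combined = [
        [sum(w * m[i][c] for w, m in zip(weights, list_logits)) / total
         for c in range(n_classes)]
        for i in range(n_nodes)
    ]
    preds = [max(range(n_classes), key=row.__getitem__) for row in combined]
    return combined, preds

# Two toy models that disagree on node 0 and agree on node 1.
model_a = [[2.0, 0.0], [0.0, 1.0]]
model_b = [[0.0, 3.0], [0.0, 1.0]]
_, preds_weighted = weighted_ensemble([model_a, model_b], [2, 1])
_, preds_equal = weighted_ensemble([model_a, model_b], [1, 1])
```

In this toy case, doubling the weight of model_a flips the prediction for node 0, which is the kind of effect the weights are meant to control.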
2.3 Misc: HP Search with Optuna
We search our hyperparameters with optuna. We implement an easy-to-use search framework for both LMs (distributed) and GNNs (single GPU). One can check out its usage at ./scripts/hp_search and the code at ./src/run_optuna and ./run_optuna.py. Below are some tips:
- For the HP search of both LMs and GNNs, we save the output of the best trial at out/${dataset}/${model_type}/${suffix}/best/.
- One can define their own search space at src/run_optuna/search_space.py.
- scripts/optuna.sh performs searching with DDP.
- scripts/single_gpu_optuna.sh performs searching with single GPU training.
For example, one can search the HPs of an LM on OGBN-Arxiv by running:
bash scripts/hp_search/peft_lm.sh ogbn-arxiv e5-large
# output of best trial is at: out/ogbn-arxiv/e5-large/optuna_peft/best
Or one can search the HPs of a GNN on OGBN-Arxiv by running:
bash scripts/hp_search/gnn.sh ogbn-arxiv e5-large GraphSAGE out/ogbn-arxiv/e5-large/optuna_peft/best/cached_embs/x_embs.pt