SIMTEG: A FRUSTRATINGLY SIMPLE APPROACH IMPROVES TEXTUAL GRAPH LEARNING — applying PEFT of language models to textual graphs (TG)


1. Paper Overview

Title: "SIMTEG: A FRUSTRATINGLY SIMPLE APPROACH IMPROVES TEXTUAL GRAPH LEARNING"

The idea of this paper is very simple.

It performs two-stage model training:

(1) Fine-tuning the language model (LM): the LM is fine-tuned on the text of the textual graph and the corresponding labels; this stage is run separately for the two tasks (node classification and link prediction).

(2) Training a GNN on the node features X generated by the fine-tuned LM for the downstream task; again, this is done separately for the two tasks (node classification and link prediction).


The essence of the paper (read together with the overview figure and the pseudocode in the original paper):

(1) f_lm: the language model combined with parameter-efficient fine-tuning (PEFT); f_mlp: a fully connected (MLP) head; f_gnn: the GNN model.

(2) Inputs: the textual graph (adjacency matrix A, token IDs T), the attention mask M, and the labels Y for each task.

Workflow:

Task 1: node classification

(1) Feed T and M into f_lm to compute X, pass X through the fully connected head to obtain the predicted label logits, compute the cross-entropy loss between the logits and the true labels Y, then backpropagate and update the weights.

(2) Use the fine-tuned LM to compute the vector representation X of T.

(3) Train the GNN with the X from above, the adjacency matrix A, and the labels Y: at each layer, aggregate messages over A and the previous layer's representation to obtain the next layer's X (the standard GNN forward pass), feed the last layer's X into a fully connected head to get the predicted logits, compute the cross-entropy loss, backpropagate, and update the weights until the model converges (a minimal sketch of steps (2) and (3) follows below).
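To make the two stages concrete, here is a minimal PyTorch/PyG sketch of the node-classification pipeline. The names encode_nodes, SAGE, and train_gnn are illustrative and not taken from the SimTeG code; it assumes the LM has already been fine-tuned with PEFT as in step (1), and uses a 2-layer GraphSAGE with the learning rate and weight decay of the example script later in this post (the hidden size of 256 is an assumption).

import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

# Step (2) sketch: mean-pool the fine-tuned LM's last hidden states to get node features X.
# `lm` is assumed to be a HuggingFace encoder already fine-tuned with PEFT on (T, M, Y).
@torch.no_grad()
def encode_nodes(lm, input_ids, attention_mask):
    out = lm(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
    mask = attention_mask.unsqueeze(-1).float()
    return (out * mask).sum(1) / mask.sum(1)   # [num_nodes, hidden_dim]

# Step (3) sketch: a 2-layer GraphSAGE trained on the frozen embeddings X.
class SAGE(torch.nn.Module):
    def __init__(self, in_dim, hid_dim, num_classes):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hid_dim)
        self.conv2 = SAGEConv(hid_dim, num_classes)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)        # logits

def train_gnn(x, edge_index, y, train_mask, epochs=100):
    model = SAGE(x.size(-1), 256, int(y.max()) + 1)
    opt = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=4e-6)
    for _ in range(epochs):
        opt.zero_grad()
        logits = model(x, edge_index)
        loss = F.cross_entropy(logits[train_mask], y[train_mask])
        loss.backward()
        opt.step()
    return model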

Task 2: link prediction

Only the loss function differs: BCEWithLogitsLoss, a loss widely used for binary classification in PyTorch, combines a sigmoid with binary cross-entropy (BCE) so that the computation is numerically more stable and efficient. The rest of the training procedure is the same.
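For illustration, here is a hedged sketch of how such a link-prediction objective is typically implemented with torch.nn.BCEWithLogitsLoss; the dot-product edge scorer and the helper name link_loss are assumptions for this example, not the repo's actual code.

import torch

# Score an edge by the dot product of its two endpoint embeddings, then train
# with BCEWithLogitsLoss over positive and (sampled) negative edges.
criterion = torch.nn.BCEWithLogitsLoss()

def link_loss(h, pos_edge_index, neg_edge_index):
    # h: [num_nodes, dim] node embeddings produced by the GNN
    pos_score = (h[pos_edge_index[0]] * h[pos_edge_index[1]]).sum(dim=-1)
    neg_score = (h[neg_edge_index[0]] * h[neg_edge_index[1]]).sum(dim=-1)
    scores = torch.cat([pos_score, neg_score])
    labels = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
    return criterion(scores, labels)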


2. Code Reproduction and Practice

Environment

conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
conda install pyg -c pyg
conda install -c dglteam/label/cu118 dgl # for RevGAT
pip install transformers
pip install optuna # for HP search
pip install deepspeed # recommended if you want to fine-tune the LM yourself
2.1 Getting Started With An Example

For all results reported in our paper on OGBN-Arxiv, OGBN-Products, and OGBL-Citation2 (2.7M), we place the training scripts at ./scripts. Here is an example of reproducing the results of e5-large (LM) + GraphSAGE (GNN).

2.1.1 First, fine-tune the language model on the specific dataset:

dataset=ogbn-arxiv
model_type=e5-large
suffix=main

# takes about half an hour on 4 A100 (40G) GPUs
bash scripts/train.sh --model_type $model_type --dataset $dataset --suffix $suffix \
    --pretrained_repo sentence-transformers/e5-large \
    --lr 5e-5 \
    --weight_decay 1e-5 \
    --batch_size 20 \
    --eval_batch_size 200 \
    --accum_interval 5 \
    --label_smoothing 0.3 \
    --epochs 10 \
    --warmup_ratio 0.15 \
    --lr_scheduler_type linear \
    --use_peft \
    --peft_r 4 \
    --peft_lora_alpha 8 \
    --peft_lora_dropout 0.3 \
    --header_dropout_prob 0.6 \
    --deepspeed ds_config.json # optional, we use stage 2 of deepspeed

All output will be saved at ./out/${dataset}/${model_type}/${suffix}. Specifically, the generated embeddings are at ./out/${dataset}/${model_type}/${suffix}/cached_embs/x_embs.pt; in this example it is ./out/ogbn-arxiv/e5-large/main/cached_embs/x_embs.pt,

or you can download the x_embs.pt from our Hugging Face repo.
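Assuming x_embs.pt stores a single tensor of node embeddings (which is how the GNN stage consumes it), it can be loaded and inspected like this:

import torch

# Load the cached node features produced by the fine-tuned LM.
x_embs = torch.load("out/ogbn-arxiv/e5-large/main/cached_embs/x_embs.pt", map_location="cpu")
print(x_embs.shape)  # expected: [num_nodes, hidden_dim], e.g. 1024 for e5-large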

2.1.2 Then we train a GraphSAGE on top of the generated embeddings:

model_type=GraphSAGE # GNN backbone for this stage (GraphSAGE in this example)
lm_model_type=e5-large
suffix=main_X_${lm_model_type}
bert_x_dir=out/ogbn-arxiv/e5-large/main/cached_embs/x_embs.pt # should be consistent with the LM output above
bash scripts/single_gpu_train.sh --model_type $model_type --dataset $dataset --suffix $suffix \
    --n_exps 10 \
    --single_gpu 0 \
    --lm_type $lm_model_type \
    --gnn_batch_size 10000 \
    --gnn_eval_batch_size 10000 \
    --gnn_epochs 100 \
    --gnn_dropout 0.4 \
    --gnn_label_smoothing 0.4 \
    --gnn_lr 0.01 \
    --gnn_num_layers 2 \
    --gnn_weight_decay 4e-6 \
    --gnn_eval_interval 1 \
    --use_bert_x \
    --bert_x_dir $bert_x_dir
2.2 Do Ensembling 

Following the above instructions, we can train GNNs on multiple sets of embeddings. To reproduce the results on OGBN-Arxiv, one should train a GNN on both the original text and the TAPE-enhanced text with different LMs (e5-large and all-roberta-large-v1).

bash scripts/ogbn-arxiv/e5-large/main.sh
bash scripts/ogbn-arxiv-tape/e5-large/main.sh
bash scripts/ogbn-arxiv/roberta-large/main.sh
bash scripts/ogbn-arxiv-tape/roberta-large/main.sh

bash scripts/ogbn-arxiv-tape/revgat/main.sh # contains all training scripts

logits1=out/ogbn-arxiv/revgat/ensemble_X_e5-large/cached_embs
logits2=out/ogbn-arxiv/revgat/ensemble_X_all-roberta-large-v1/cached_embs
logits3=out/ogbn-arxiv-tape/revgat/ensemble_X_e5-large/cached_embs
logits4=out/ogbn-arxiv-tape/revgat/ensemble_X_all-roberta-large-v1/cached_embs
logits5=out/ogbn-arxiv/revgat/ensemble_preds/cached_embs

python compute_ensemble.py \
    --list_logits "${logits1} ${logits2} ${logits3} ${logits4} ${logits5}" \
    --weights 2 2 1 1 1 \
    --start_seed 1
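The exact logic of compute_ensemble.py is not shown here, but weighted ensembling of saved logits amounts to something like the following sketch (whether the repo averages raw logits or softmax probabilities, and how it handles multiple seeds, is an assumption here; the file names are hypothetical):

import torch

def ensemble(logit_files, weights):
    # Load per-model logits of shape [num_nodes, num_classes] and combine them
    # with the given weights; here probabilities are averaged after softmax.
    probs = [torch.softmax(torch.load(f, map_location="cpu"), dim=-1) for f in logit_files]
    combined = sum(w * p for w, p in zip(weights, probs))
    return combined.argmax(dim=-1)  # final predicted class per node

preds = ensemble(
    ["logits_model1.pt", "logits_model2.pt"],  # hypothetical file names
    weights=[2.0, 1.0],
)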
2.3 Misc: HP Search with Optuna 

We search our hyperparameters with Optuna and implement an easy-to-use search framework for both LMs (distributed) and GNNs (single GPU). One can check out its usage at ./scripts/hp_search and the code at ./src/run_optuna and ./run_optuna.py. Below are some tips:

  1. For the HP search of both LMs and GNNs, we save the output of the best trial at out/${dataset}/${model_type}/${suffix}/best/.
  2. One can define their own search space at src/run_optuna/search_space.py (a hypothetical sketch follows after this list).
  3. scripts/optuna.sh performs searching with DDP; scripts/single_gpu_optuna.sh performs searching with single-GPU training.
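As referenced in tip 2, here is a hypothetical search space in the style of src/run_optuna/search_space.py. The parameter names mirror the GNN flags used in the training script above, but the ranges and the function name gnn_search_space are assumptions.

import optuna

def gnn_search_space(trial: optuna.Trial) -> dict:
    # Each trial suggests a set of GNN hyperparameters, which is then passed
    # on to the single-GPU GNN trainer.
    return {
        "gnn_lr": trial.suggest_float("gnn_lr", 1e-3, 1e-1, log=True),
        "gnn_dropout": trial.suggest_float("gnn_dropout", 0.1, 0.7),
        "gnn_num_layers": trial.suggest_int("gnn_num_layers", 2, 4),
        "gnn_label_smoothing": trial.suggest_float("gnn_label_smoothing", 0.0, 0.5),
    }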

 

For example, one can search the HP of an LM on OGBN-Arxiv by running:

bash scripts/hp_search/peft_lm.sh ogbn-arxiv e5-large
# output of the best trial is at: out/ogbn-arxiv/e5-large/optuna_peft/best

or one can search the HP of a GNN on OGBN-Arxiv by running: 

bash scripts/hp_search/gnn.sh ogbn-arxiv e5-large GraphSAGE out/ogbn-arxiv/e5-large/optuna_peft/best/cached_embs/x_embs.pt