Python OFA:TOWARDS TRAINING ONE GRAPHMODEL FOR ALL CLASSIFICATION TASKS,基于LLMs处理多种图任务的基础模型,模型代码实战

25 篇文章 1 订阅
18 篇文章 1 订阅

1. 文章描述 

本文提出One-for-All (OFA),一种用于构建和训练具有跨不同领域上下文学习能力的基础GNN模型的通用解决方案。OFA有三个主要的独特特性:
(1)OFA使用文本属性图(TAG)将来自不同领域的图数据集集成到一个文本属性图数据集,并利用大模型llm的能力从所有领域联合学习。我们收集了社区中常用的9个图形数据集,其大小、领域和任务类型各不相同(完整列表见表1)。然后,使用人类可读的文本描述图中的所有节点和边,并使用单个LLM将来自不同领域的文本嵌入到相同的嵌入空间。
(2)OFA提出了感兴趣节点(NOI)子图和NOI提示节点,不仅统一了不同类型的图任务,而且提高了基础模型学习图中结构信息的能力。
(3)OFA引入了一种精心设计并广泛适用的图提示范式(graph prompt paradigm,GPP),以特定于任务的方式将提示图插入原始输入图中。提示图中的节点包含有关下游任务的所有相关信息(由文本描述并由与输入图相同的LLM编码器编码)。然后,修改后的
图成为基础图模型的实际输入。因此,该模型可以根据提示图执行不同的任务。图1说明了OFA的管道。经过训练,用户可以用自然文本描述任何图,并应用OFA管道来预测可能未见过的类别。 

本文所提出的OFA是一个通用的图学习框架,使用一个模型来同时解决不同格式和背景的分类任务,类似于LLMs,可以使用相同的模型权重回答本质上不同的问题。图1说明了OFA的管道。
OFA可以分为三个部分。首先,将来自不同领域的图集成到具有相同格式的文本属性图中,允许单个LLM将所有标签嵌入到相同的空间中。
在第二部分中,OFA通过引入NOI子图和NOI提示节点来统一图中不同的任务类型,使图模型能够自动关注与任务相关的信息。
最后,OFA提出了图提示范式(Graph prompt Paradigm, GPP),将任务信息有机地注入到图数据中,从而实现上下文学习。 

构建图基础模型的一个关键挑战是,跨域图数据通常由完全不同的过程生成,并且具有嵌入在不同空间中的节点/边属性。这使得在一个域上训练的图模型几乎不可能泛化到另一个域。然而,尽管不同数据集的属性不同,但几乎所有数据集都可以用人类可解释的语言来描述。例如,在分子图中,节点表示原子,我们可以使用纯文本来描述具有原子特征的节点,包括元素名称、手性等。关键的优势是,通过使用文本来描述节点和边,我们可以应用LLM将不同的图属性编码到相同的空间。因此,引入了文本属性图的概念来系统地集成来自不同领域的图数据。 

图领域中的下游分类任务可以分为以下不同类别:
(1)节点级任务,其中任务是对图中的节点进行分类;
(2)链路级任务,其任务是推理节点对之间的连接;
(3)图级任务,任务是在整个图上进行预测。
然而,不同层次的任务需要不同的过程和方法来处理,这使得构建图的基础模型变得困难。相比之下,语言中的不同下游任务具有相同的自回归生成性质,这使得从llm中使用的下一个token预测任务中学习到的知识一致有利于各种下游任务。那么问题来了:我们能否将不同的图任务统一为一个单一的任务,以促进图领域的训练和知识迁移?

 

llm最吸引人的特性之一是其通过提示进行上下文学习的能力,这允许模型在不同的学习场景中执行各种下游任务,而无需微调。例如,在一个目标是根据标题和摘要预测论文类别的少样本场景中,我们可以为LLMs提供每个类别的k篇论文作为上下文,并指示模型根据所提供的上下文生成预测。然而,对图进行上下文学习的研究仍然相对未知。

本文提出OFA,第一个用于构建用于图学习的基础GNN模型的解决方案。具体而言,文本属性图、兴趣节点和图提示范式赋予了OFA使用单个图模型执行不同领域图相关任务的能力。在不同图域的监督、少样本和零样本学习场景中评估了OFA。OFA在大多数任务上都显示了良好的结果,并在图上显示了作为未来基础模型的巨大潜力。 

 补充一下整体思路

1. 将所有类型的Graph数据集转化为文本图TAGs,即图中所有的节点node和边edge都按照模板用文本描述;

2. 通过LMs(文中选了sentence-transformer,e5,llama2-7b,13b)将不通的graph数据集embedding到相同的向量空间,得到在同一空间的向量表示:节点Xi^{_{}}和边Xij,参考下面的公式7;

3. 构建兴趣节点NOI,其实不是构建的,即关心的那部分节点就是所谓的NOI节点,节点分类就是待分类的一个节点,链接预测就是相连的两个节点,图分类就是一整个子图的所有节点;

4. 构建NOI提示节点,NOI提示节点与NOI节点相互连接,并且构建prompt;

5. 构建NOI提示节点的分类节点,并与NOI提示节点相连接,有几个类别就有几个节点;这就等同于把所有的任务都转化为分类节点的得分分类任务了;如果是二分类,就用sigmoid最后计算出得分即可,公式10,如果是多分类任务,那么计算argmax得分即可,公式11;

6. 损失函数:分类任务的交叉熵损失,公式12。

 

2. 代码实战

环境安装:To install requirement for the project using conda

conda env create -f environment.yml

 在所有数据集上的端到端实验:For joint end-to-end experiments on all collected dataset, run

 python run_cdm.py --override e2e_all_config.yaml

All arguments can be changed by space separated values such as 

python run_cdm.py --override e2e_all_config.yaml num_layers 7 batch_size 512 dropout 0.15 JK none

Users can modify the task_names variable in ./e2e_all_config.yaml to control which datasets are included during training. The length of task_namesd_multiple, and d_min_ratio should be the same. They can also be specified in command line arguments by comma separated values.

 python run_cdm.py task_names cora_link,arxiv d_multiple 1,1 d_min_ratio 1,1

OFA-ind can be specified by

python run_cdm.py task_names cora_link d_multiple 1 d_min_ratio 1

Low resource experiments --To run the few-shot and zero-shot experiments

 python run_cdm.py --override lr_all_config.yaml

Configuration explained 

We define configurations for each task, each task configurations contains several datasets configurations.

Task configurations are stored in ./configs/task_config.yaml. A task usually consists several splits of datasets (not necessarily same datasets). For example, a regular end-to-end Cora node classification task will have the train split of the Cora dataset as the train dataset, the valid split of the Cora dataset as one of the valid dataset, and likewise for the test split. You can also have more validation/test by specifying the train split of the Cora as one of the validation/test datasets. Specifically, a task configuration looks like

arxiv:
  eval_pool_mode: mean
  dataset: arxiv             # dataset name
  eval_set_constructs:
    - stage: train           # a task should have one and only one train stage dataset
      split_name: train
    - stage: valid
      split_name: valid
      dataset: cora          # replace the default dataset for zero-shot tasks
    - stage: valid
      split_name: valid
    - stage: test
      split_name: test
    - stage: test
      split_name: train      # test the train split

Dataset configurations are stored in ./configs/task_config.yaml. A dataset configuration defines how a dataset is constructed. Specifically,

arxiv:
  task_level: e2e_node
  preprocess: null                       # name of the preprocess function defined in task_constructor.py
  construct: ConstructNodeCls            # name of the dataset construction function defined in task_constructor.py
  args: # additional arguments to construct function
    walk_length: null
    single_prompt_edge: True
  eval_metric: acc                       # evaluation metric
  eval_func: classification_func         # evaluation function that process model output and batch to input to evaluator
  eval_mode: max                         # evaluation mode (min/max)
  dataset_name: arxiv                    # name of the OFAPygDataset
  dataset_splitter: ArxivSplitter        # splitting function defined in task_constructor.py
  process_label_func: process_pth_label  # name of process label function that transform original label to the binary labels
  num_classes: 40 

3. Add your own datasets

If you are implementing a dataset like Cora/pubmed/Arxiv, we recommend adding a directory of your data $customized_data $ under data/single_graph/$customized_data$ and implement gen_data.py under the directory, you can use data/Cora/gen_data.py as an example.

After the data is constructed, you need to register you dataset name in here , and implement a splitter like here. If you are doing zero-shot/few-shot tasks, you can constructor zero-shot/few-shot split here too.

Lastly, register a config entry in configs/data_config.yaml. For example, for end-to-end node classification

$data_name$:
  <<: *E2E-node
  dataset_name: $data_name$
  dataset_splitter: $splitter$
  process_label_func: ... # usually processs_pth_label should work
  num_classes: $number of classes$

process_label_func converts the target label to binary label, and transform class embedding if the task is zero-shot/few-shot, where the number of class node is not fixed. A list of avalailable process_label_func is here. It takes in all classes embedding and the correct label. The output is a tuple : (label, class_node_embedding, binary/one-hot label).

If you want more flexibility, then adding customized datasets requires implementation of a customized subclass of OFAPygDataset .A template is here:

class CustomizedOFADataset(OFAPygDataset):
    def gen_data(self):
        """
        Returns a tuple of the following format
        (data, text, extra) 
        data: a list of Pyg Data, if you only have a one large graph, you should still wrap it with the list.
        text: a list of list of texts. e.g. [node_text, edge_text, label_text] this is will be converted to pooled vector representation.
        extra: any extra data (e.g. split information) you want to save.
        """

    def add_text_emb(self, data_list, text_emb):
        """
        This function assigns generated embedding to member variables of the graph

        data_list: data list returned in self.gen_data.
        text_emb: list of torch text tensor corresponding to the returned text in self.gen_data. text_emb[0] = llm_encode(text[0])

        
        """
        data_list[0].node_text_feat = ...     # corresponding node features
        data_list[0].edge_text_feat = ...      # corresponding edge features
        data_list[0].class_node_text_feat = ...      # class node features
        data_list[0].prompt_edge_text_feat = ...     # edge features used in prompt node
        data_list[0].noi_node_text_feat = ...       # noi node features, refer to the paper for the definition
        return self.collate(data_list)

    def get_idx_split(self):
        """
        Return the split information required to split the dataset, this optional, you can further split the dataset in task_constructor.py
        
        """

    def get_task_map(self):
        """
        Because a dataset can have multiple different tasks that requires different prompt/class text embedding. This function returns a task map that maps a task name to the desired text embedding. Specifically, a task map is of the following format.

        prompt_text_map = {task_name1: {"noi_node_text_feat": ["noi_node_text_feat", [$Index in data[0].noi_node_text_feat$]],
                                    "class_node_text_feat": ["class_node_text_feat",
                                                             [$Index in data[0].class_node_text_feat$]],
                                    "prompt_edge_text_feat": ["prompt_edge_text_feat", [$Index in data[0].prompt_edge_text_feat$]]},
                       task_name2: similar to task_name 1}
        Please refer to examples in data/ for details.
        """
        return self.side_data[-1]

    def get_edge_list(self, mode="e2e"):
        """
        Defines how to construct prompt graph
        f2n: noi nodes to noi prompt node
        n2f: noi prompt node to noi nodes
        n2c: noi prompt node to class nodes
        c2n: class nodes to noi prompt node
        For different task/mode you might want to use different prompt graph construction, you can do so by returning a dictionary. For example
        {"f2n":[1,0], "n2c":[2,0]} means you only want f2n and n2c edges, f2n edges have edge type 1, and its text embedding feature is data[0].prompt_edge_text_feat[0]
        """
        if mode == "e2e_link":
            return {"f2n": [1, 0], "n2f": [3, 0], "n2c": [2, 0], "c2n": [4, 0]}
        elif mode == "lr_link":
            return {"f2n": [1, 0], "n2f": [3, 0]}

运行主文件:run_cdm.py 

import argparse
import os
from types import SimpleNamespace

import torch
from pytorch_lightning.loggers import WandbLogger
from torchmetrics import AUROC, Accuracy

import utils
from gp.lightning.data_template import DataModule
from gp.lightning.metric import (
    flat_binary_func,
    EvalKit,
)
from gp.lightning.module_template import ExpConfig
from gp.lightning.training import lightning_fit
from gp.utils.utils import (
    load_yaml,
    combine_dict,
    merge_mod,
    setup_exp,
    set_random_seed,
)
from lightning_model import GraphPredLightning
from models.model import BinGraphModel, BinGraphAttModel
from models.model import PyGRGCNEdge
from task_constructor import UnifiedTaskConstructor
from utils import (
    SentenceEncoder,
    MultiApr,
    MultiAuc,
)


# os.environ["CUDA_LAUNCH_BLOCKING"]="1"

def main(params):
    """
    0. Check GPU setting.
    """
    device, gpu_ids = utils.get_available_devices()
    gpu_size = len(gpu_ids)

    """
    1. Initiate task constructor.
    """
    encoder = SentenceEncoder(params.llm_name, batch_size=params.llm_b_size)

    task_config_lookup = load_yaml(
        os.path.join(os.path.dirname(__file__), "configs", "task_config.yaml")
    )
    data_config_lookup = load_yaml(os.path.join(os.path.dirname(__file__), "configs", "data_config.yaml"))

    if isinstance(params.task_names, str):
        task_names = [a.strip() for a in params.task_names.split(",")]
    else:
        task_names = params.task_names

    tasks = UnifiedTaskConstructor(
        task_names,
        params.load_texts,
        encoder,
        task_config_lookup,
        data_config_lookup,
        batch_size=params.batch_size,
        sample_size=params.train_sample_size,
    )
    val_task_index_lst, val_pool_mode = tasks.construct_exp()

    # remove llm model
    if encoder is not None:
        encoder.flush_model()

    """
    2. Load model 
    """
    out_dim = params.emb_dim + (params.rwpe if params.rwpe is not None else 0)

    gnn = PyGRGCNEdge(
        params.num_layers,
        5,
        out_dim,
        out_dim,
        drop_ratio=params.dropout,
        JK=params.JK,
    )

    bin_model = BinGraphAttModel if params.JK == "none" else BinGraphModel
    model = bin_model(model=gnn, llm_name=params.llm_name, outdim=out_dim, task_dim=1,
                      add_rwpe=params.rwpe, dropout=params.dropout)

    """
    3. Construct datasets and lightning datamodule.
    """

    if hasattr(params, "d_multiple"):
        if isinstance(params.d_multiple, str):
            data_multiple = [float(a) for a in params.d_multiple.split(",")]
        else:
            data_multiple = params.d_multiple
    else:
        data_multiple = [1]

    if hasattr(params, "d_min_ratio"):
        if isinstance(params.d_min_ratio, str):
            min_ratio = [float(a) for a in params.d_min_ratio.split(",")]
        else:
            min_ratio = params.d_min_ratio
    else:
        min_ratio = [1]


    train_data = tasks.make_train_data(data_multiple, min_ratio, data_val_index=val_task_index_lst)

    text_dataset = tasks.make_full_dm_list(
        data_multiple, min_ratio, train_data
    )
    params.datamodule = DataModule(
        text_dataset, gpu_size=gpu_size, num_workers=params.num_workers
    )

    """
    4. Initiate evaluation kit. 
    """
    eval_data = text_dataset["val"] + text_dataset["test"]
    val_state = [dt.state_name for dt in text_dataset["val"]]
    test_state = [dt.state_name for dt in text_dataset["test"]]
    eval_state = val_state + test_state
    eval_metric = [dt.metric for dt in eval_data]
    eval_funcs = [dt.meta_data["eval_func"] for dt in eval_data]
    loss = torch.nn.BCEWithLogitsLoss()
    evlter = []
    for dt in eval_data:
        if dt.metric == "acc":
            evlter.append(Accuracy(task="multiclass", num_classes=dt.classes))
        elif dt.metric == "auc":
            evlter.append(AUROC(task="binary"))
        elif dt.metric == "apr":
            evlter.append(MultiApr(num_labels=dt.classes))
        elif dt.metric == "aucmulti":
            evlter.append(MultiAuc(num_labels=dt.classes))
    metrics = EvalKit(
        eval_metric,
        evlter,
        loss,
        eval_funcs,
        flat_binary_func,
        eval_mode="max",
        exp_prefix="",
        eval_state=eval_state,
        val_monitor_state=val_state[0],
        test_monitor_state=test_state[0],
    )

    """
    5. Initiate optimizer, scheduler and lightning model module.
    """
    optimizer = torch.optim.Adam(
        model.parameters(), lr=params.lr, weight_decay=params.l2
    )
    lr_scheduler = {
        "scheduler": torch.optim.lr_scheduler.StepLR(optimizer, 15, 0.5),
        "interval": "epoch",
        "frequency": 1,
    }

    exp_config = ExpConfig(
        "",
        optimizer,
        dataset_callback=train_data.update,
        lr_scheduler=lr_scheduler,
    )
    exp_config.val_state_name = val_state
    exp_config.test_state_name = test_state

    pred_model = GraphPredLightning(exp_config, model, metrics)

    """
    6. Start training and logging.
    """
    wandb_logger = WandbLogger(
        project=params.log_project,
        name=params.exp_name,
        save_dir=params.exp_dir,
        offline=params.offline_log,
    )


    strategy = "deepspeed_stage_2" if gpu_size > 1 else "auto"
    val_res, test_res = lightning_fit(
        wandb_logger,
        pred_model,
        params.datamodule,
        metrics,
        params.num_epochs,
        strategy=strategy,
        save_model=False,
        load_best=params.load_best,
        reload_freq=1,
        test_rep=params.test_rep,
        val_interval=params.val_interval
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="rl")
    parser.add_argument("--override", type=str)

    parser.add_argument(
        "opts",
        default=[],
        nargs=argparse.REMAINDER,
        help="Modify config options using the command-line",
    )

    params = parser.parse_args()
    configs = []
    configs.append(
        load_yaml(
            os.path.join(
                os.path.dirname(__file__), "configs", "default_config.yaml"
            )
        )
    )

    if params.override is not None:
        override_config = load_yaml(params.override)
        configs.append(override_config)
    # Add for few-shot parameters

    mod_params = combine_dict(*configs)
    mod_params = merge_mod(mod_params, params.opts)
    setup_exp(mod_params)

    params = SimpleNamespace(**mod_params)
    set_random_seed(params.seed)

    torch.set_float32_matmul_precision("high")
    params.log_project = "full_cdm"

    params.exp_name += f"_{params.llm_name}_ofa1"

    print(params)
    main(params)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

医学小达人

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值