Multi-Task Learning Models: AdaTT

AdaTT (Adaptive Task-to-Task Fusion Network) is a multi-task learning model proposed by Meta. It is designed to model complex task relationships effectively through an adaptive fusion mechanism and to support joint learning of task-specific and shared knowledge.

1. The AdaTT Model

Core features of the AdaTT model:

  1. Adaptive fusion mechanism: AdaTT uses an adaptive fusion mechanism to handle the relationships between tasks, allowing the model to dynamically adjust information exchange and knowledge sharing across tasks.
  2. Multi-level fusion: AdaTT fuses task-specific and shared knowledge at multiple levels, which helps the model capture feature representations at different granularities.
  3. Residual connections: residual connections let gradients flow directly to earlier layers, making deep networks easier to train and mitigating the vanishing-gradient problem.
  4. Gating mechanism: AdaTT uses gating to control the flow of information between tasks, so the model can adaptively select the information most useful for the current task.

Structure of the AdaTT model:

  • Task-specific fusion units: each task has its own fusion unit responsible for processing information relevant to that task.
  • Optional shared fusion unit: besides the task-specific units, AdaTT can include an optional shared fusion unit, shared across all tasks, that captures features common to the tasks.
  • Residual connections and gating networks: residual connections carry information effectively between fusion units, while gating networks regulate the information flow (a schematic sketch of one fusion layer follows below).
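
To make the fusion concrete, here is a schematic sketch of one fusion layer, mirroring the AdaTTSp code listed in the next section. It is an illustration only: random tensors stand in for real expert and gate networks, and the shapes (B, T, E_t, D) are arbitrary choices.

import torch

# one AdaTT fusion layer (schematic): B samples, T tasks, E_t experts per task,
# expert output dimension D; E = T * E_t experts in total
B, T, E_t, D = 32, 3, 2, 64
E = T * E_t
experts_out = torch.randn(B, E, D)                   # outputs of all experts in the layer
gates = torch.softmax(torch.randn(B, T, E), dim=-1)  # each task's gate over all experts
fused = torch.bmm(gates, experts_out)                # adaptive fusion -> [B, T, D]
native = experts_out.view(B, T, E_t, D)              # each task's own ("native") experts
w = torch.randn(T, E_t)                              # learned residual weights per task
fused = fused + torch.einsum("te,bted->btd", w, native)  # residual from native experts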

Comparison with existing models:

  • Compared with models such as MMoE and PLE, AdaTT's adaptive fusion mechanism lets it handle inter-task relationships more flexibly.
  • AdaTT is structurally similar to PLE, but improves on it by adding a simple linear-mapping expert network whose output is summed with the output of the shared experts, increasing the model's adaptivity.

Effectiveness of AdaTT:

  • According to the paper, AdaTT performs well across experiments with different degrees of task relatedness; whether the tasks are weakly related, strongly related, or mixed, AdaTT achieves the best results.
  • AdaTT obtains the lowest Normalized Entropy (NE), indicating the best performance among the compared multi-task models (a sketch of the NE metric follows below).
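
For reference, Normalized Entropy is usually defined as the average per-example log loss divided by the entropy of the background (average) positive rate, so lower is better. A minimal sketch, assuming binary labels and predicted probabilities:

import numpy as np

def normalized_entropy(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-12) -> float:
    # average cross-entropy of the predictions
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    ce = -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    # entropy of the empirical positive rate (the "background" rate)
    p = np.clip(np.mean(y_true), eps, 1.0 - eps)
    background = -(p * np.log(p) + (1 - p) * np.log(1 - p))
    return float(ce / background)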

Applications of AdaTT:

  • AdaTT is particularly well suited to multi-task learning problems in recommender systems, where it can model several objectives at once, such as user interest and purchase intent, improving recommendation accuracy and conversion rate.

Summary:

AdaTT is an advanced multi-task learning model. Through its adaptive fusion mechanism and multi-level fusion strategy, it handles the relationships between tasks effectively and promotes both knowledge sharing and task-specific learning. Its use in recommender systems and other areas demonstrates strong performance and broad applicability.

2. AdaTT Code

The implementation of AdaTT (Adaptive Task-to-Task Fusion Network) is available in the official code repository on GitHub. The following covers the main steps and components of the implementation, quoted from the official repository:

Overview of the AdaTT code repository

The official AdaTT repository provides a PyTorch library for multi-task learning, focused on the models evaluated in the paper "AdaTT: Adaptive Task-to-Task Fusion Network for Multitask Learning in Recommendations" (KDD '23).

# Copyright (c) Meta Platforms, Inc. and affiliates.

import logging
from dataclasses import dataclass, field
from math import sqrt
from typing import List, Optional, Union

import torch
import torch.nn as nn


logger: logging.Logger = logging.getLogger(__name__)


@dataclass
class MtlConfigs:
    mtl_model: str = "att_sp"  # consider using enum
    num_task_experts: int = 1
    num_shared_experts: int = 1
    expert_out_dims: List[List[int]] = field(default_factory=list)
    self_exp_res_connect: bool = False
    expert_archs: Optional[List[List[int]]] = None
    gate_archs: Optional[List[List[int]]] = None
    num_experts: Optional[int] = None


@dataclass(frozen=True)
class ArchInputs:
    num_task: int = 3

    task_mlp: List[int] = field(default_factory=list)

    mtl_configs: Optional[MtlConfigs] = field(default=None)

    # Parameters related to activation function
    activation_type: str = "RELU"


class AdaTTSp(nn.Module):
    """
    paper title: "AdaTT: Adaptive Task-to-Task Fusion Network for Multitask Learning in Recommendations"
    paper link: https://doi.org/10.1145/3580305.3599769
    Call Args:
        inputs: inputs is a tensor of dimension
            [batch_size, self.num_tasks, self.input_dim].
            Experts in the same module share the same input.
        outputs dimensions: [B, T, D_out]

    Example::
        AdaTTSp(
            input_dim=256,
            expert_out_dims=[[128, 128]],
            num_tasks=8,
            num_task_experts=2,
            self_exp_res_connect=True,
        )
    """

    def __init__(
        self,
        input_dim: int,
        expert_out_dims: List[List[int]],
        num_tasks: int,
        num_task_experts: int,
        self_exp_res_connect: bool = True,
        activation: str = "RELU",
    ) -> None:
        super().__init__()
        if len(expert_out_dims) == 0:
            logger.warning(
                "AdaTTSp is noop! size of expert_out_dims which is the number of "
                "extraction layers should be at least 1."
            )
            return
        self.num_extraction_layers: int = len(expert_out_dims)
        self.num_tasks = num_tasks
        self.num_task_experts = num_task_experts
        self.total_experts_per_layer: int = num_task_experts * num_tasks
        self.self_exp_res_connect = self_exp_res_connect
        self.experts = torch.nn.ModuleList()
        self.gate_weights = torch.nn.ModuleList()

        self_exp_weight_list = []
        layer_input_dim = input_dim
        for expert_out_dim in expert_out_dims:
            self.experts.append(
                torch.nn.ModuleList(
                    [
                        MLP(layer_input_dim, expert_out_dim, activation)
                        for i in range(self.total_experts_per_layer)
                    ]
                )
            )

            self.gate_weights.append(
                torch.nn.ModuleList(
                    [
                        torch.nn.Sequential(
                            torch.nn.Linear(
                                layer_input_dim, self.total_experts_per_layer
                            ),
                            torch.nn.Softmax(dim=-1),
                        )
                        for _ in range(num_tasks)
                    ]
                )
            )  # self.gate_weights is of shape L X T, after we loop over all layers.

            if self_exp_res_connect and num_task_experts > 1:
                params = torch.empty(num_tasks, num_task_experts)
                scale = sqrt(1.0 / num_task_experts)
                torch.nn.init.uniform_(params, a=-scale, b=scale)
                self_exp_weight_list.append(torch.nn.Parameter(params))

            layer_input_dim = expert_out_dim[-1]

        self.self_exp_weights = nn.ParameterList(self_exp_weight_list)

    def forward(
        self,
        inputs: torch.Tensor,
    ) -> torch.Tensor:
        for layer_i in range(self.num_extraction_layers):
            # all task expert outputs.
            experts_out = torch.stack(
                [
                    expert(inputs[:, expert_i // self.num_task_experts, :])
                    for expert_i, expert in enumerate(self.experts[layer_i])
                ],
                dim=1,
            )  # [B * E (total experts) * D_out]

            gates = torch.stack(
                [
                    gate_weight(
                        inputs[:, task_i, :]
                    )  #  W ([B, D]) * S ([D, E]) -> G, dim is [B, E]
                    for task_i, gate_weight in enumerate(self.gate_weights[layer_i])
                ],
                dim=1,
            )  # [B, T, E]
            fused_experts_out = torch.bmm(
                gates,
                experts_out,
            )  # [B, T, E] X [B * E (total experts) * D_out] -> [B, T, D_out]

            if self.self_exp_res_connect:
                if self.num_task_experts > 1:
                    # residual from the linear combination of tasks' own experts.
                    self_exp_weighted = torch.einsum(
                        "te,bted->btd",
                        self.self_exp_weights[layer_i],
                        experts_out.view(
                            experts_out.size(0),
                            self.num_tasks,
                            self.num_task_experts,
                            -1,
                        ),  # [B * E (total experts) * D_out] -> [B * T * E_task * D_out]
                    )  #  bmm: [T * E_task] X [B * T * E_task * D_out] -> [B, T, D_out]

                    fused_experts_out = (
                        fused_experts_out + self_exp_weighted
                    )  # [B, T, D_out]
                else:
                    fused_experts_out = fused_experts_out + experts_out

            inputs = fused_experts_out

        return inputs


class AdaTTWSharedExps(nn.Module):
    """
    paper title: "AdaTT: Adaptive Task-to-Task Fusion Network for Multitask Learning in Recommendations"
    paper link: https://doi.org/10.1145/3580305.3599769
    Call Args:
        inputs: inputs is a tensor of dimension
            [batch_size, self.num_tasks, self.input_dim].
            Experts in the same module share the same input.
        outputs dimensions: [B, T, D_out]

    Example::
        AdaTTWSharedExps(
            input_dim=256,
            expert_out_dims=[[128, 128]],
            num_tasks=8,
            num_shared_experts=1,
            num_task_experts=2,
            self_exp_res_connect=True,
        )
    """

    def __init__(
        self,
        input_dim: int,
        expert_out_dims: List[List[int]],
        num_tasks: int,
        num_shared_experts: int,
        num_task_experts: Optional[int] = None,
        num_task_expert_list: Optional[List[int]] = None,
        # Set num_task_expert_list for experimenting with a flexible number of
        # experts for different task_specific units.
        self_exp_res_connect: bool = True,
        activation: str = "RELU",
    ) -> None:
        super().__init__()
        if len(expert_out_dims) == 0:
            logger.warning(
                "AdaTTWSharedExps is noop! size of expert_out_dims which is the number of "
                "extraction layers should be at least 1."
            )
            return
        self.num_extraction_layers: int = len(expert_out_dims)
        self.num_tasks = num_tasks
        assert (num_task_experts is None) ^ (num_task_expert_list is None)
        if num_task_experts is not None:
            self.num_expert_list = [num_task_experts for _ in range(num_tasks)]
        else:
            # num_expert_list is guaranteed to be not None here.
            # pyre-ignore
            self.num_expert_list: List[int] = num_task_expert_list
        self.num_expert_list.append(num_shared_experts)

        self.total_experts_per_layer: int = sum(self.num_expert_list)
        self.self_exp_res_connect = self_exp_res_connect
        self.experts = torch.nn.ModuleList()
        self.gate_weights = torch.nn.ModuleList()

        layer_input_dim = input_dim
        for layer_i, expert_out_dim in enumerate(expert_out_dims):
            self.experts.append(
                torch.nn.ModuleList(
                    [
                        MLP(layer_input_dim, expert_out_dim, activation)
                        for i in range(self.total_experts_per_layer)
                    ]
                )
            )

            num_full_active_modules = (
                num_tasks
                if layer_i == self.num_extraction_layers - 1
                else num_tasks + 1
            )

            self.gate_weights.append(
                torch.nn.ModuleList(
                    [
                        torch.nn.Sequential(
                            torch.nn.Linear(
                                layer_input_dim, self.total_experts_per_layer
                            ),
                            torch.nn.Softmax(dim=-1),
                        )
                        for _ in range(num_full_active_modules)
                    ]
                )
            )  # self.gate_weights is a 2d module list of shape L X T (+ 1), after we loop over all layers.

            layer_input_dim = expert_out_dim[-1]

        self_exp_weight_list = []
        if self_exp_res_connect:
            # If any tasks have number of experts not equal to 1, we learn linear combinations of native experts.
            if any(num_experts != 1 for num_experts in self.num_expert_list):
                for i in range(num_tasks + 1):
                    num_full_active_layer = (
                        self.num_extraction_layers - 1
                        if i == num_tasks
                        else self.num_extraction_layers
                    )
                    params = torch.empty(
                        num_full_active_layer,
                        self.num_expert_list[i],
                    )
                    scale = sqrt(1.0 / self.num_expert_list[i])
                    torch.nn.init.uniform_(params, a=-scale, b=scale)
                    self_exp_weight_list.append(torch.nn.Parameter(params))

        self.self_exp_weights = nn.ParameterList(self_exp_weight_list)

        self.expert_input_idx: List[int] = []
        for i in range(num_tasks + 1):
            self.expert_input_idx.extend([i for _ in range(self.num_expert_list[i])])

    def forward(
        self,
        inputs: torch.Tensor,
    ) -> torch.Tensor:
        for layer_i in range(self.num_extraction_layers):
            num_full_active_modules = (
                self.num_tasks
                if layer_i == self.num_extraction_layers - 1
                else self.num_tasks + 1
            )
            # all task expert outputs.
            experts_out = torch.stack(
                [
                    expert(inputs[:, self.expert_input_idx[expert_i], :])
                    for expert_i, expert in enumerate(self.experts[layer_i])
                ],
                dim=1,
            )  # [B * E (total experts) * D_out]

            # gate weights for fusing all experts.
            gates = torch.stack(
                [
                    gate_weight(inputs[:, i, :])  #  [B, D] * [D, E] -> [B, E]
                    for i, gate_weight in enumerate(self.gate_weights[layer_i])
                ],
                dim=1,
            )  # [B, T (+ 1), E]

            # add all expert gate weights with native expert weights.
            if self.self_exp_res_connect:
                prev_idx = 0
                use_unit_naive_weights = all(
                    num_expert == 1 for num_expert in self.num_expert_list
                )
                for module_i in range(num_full_active_modules):
                    next_idx = self.num_expert_list[module_i] + prev_idx
                    if use_unit_naive_weights:
                        gates[:, module_i, prev_idx:next_idx] += torch.ones(
                            1, self.num_expert_list[module_i]
                        )
                    else:
                        gates[:, module_i, prev_idx:next_idx] += self.self_exp_weights[
                            module_i
                        ][layer_i].unsqueeze(0)
                    prev_idx = next_idx

            fused_experts_out = torch.bmm(
                gates,
                experts_out,
            )  # [B, T (+ 1), E (total)] X [B * E (total) * D_out] -> [B, T (+ 1), D_out]

            inputs = fused_experts_out

        return inputs


class MLP(nn.Module):
    """
    Args:
        input_dim (int):
        mlp_arch (List[int]):
        activation (str):

    Call Args:
        input (torch.Tensor): tensor of shape (B, I)

    Returns:
        output (torch.Tensor): MLP result

    Example::

        mlp = MLP(100, [100])

    """

    def __init__(
        self,
        input_dim: int,
        mlp_arch: List[int],
        activation: str = "RELU",
        bias: bool = True,
    ) -> None:
        super().__init__()

        mlp_net = []
        for mlp_dim in mlp_arch:
            mlp_net.append(
                nn.Linear(in_features=input_dim, out_features=mlp_dim, bias=bias)
            )
            if activation == "RELU":
                mlp_net.append(nn.ReLU())
            else:
                raise ValueError("only RELU is included currently")
            input_dim = mlp_dim
        self.mlp_net = nn.Sequential(*mlp_net)

    def forward(
        self,
        input: torch.Tensor,
    ) -> torch.Tensor:
        return self.mlp_net(input)


class SharedBottom(nn.Module):
    def __init__(
        self, input_dim: int, hidden_dims: List[int], num_tasks: int, activation: str
    ) -> None:
        super().__init__()
        self.bottom_projection = MLP(input_dim, hidden_dims, activation)
        self.num_tasks: int = num_tasks

    def forward(
        self,
        input: torch.Tensor,
    ) -> torch.Tensor:
        # input dim [B, D_in]
        # output dim [B, T, D_out]
        return self.bottom_projection(input).unsqueeze(1).expand(-1, self.num_tasks, -1)


class CrossStitch(torch.nn.Module):
    """
    cross-stitch
    paper title: "Cross-stitch Networks for Multi-task Learning".
    paper link: https://openaccess.thecvf.com/content_cvpr_2016/papers/Misra_Cross-Stitch_Networks_for_CVPR_2016_paper.pdf
    """

    def __init__(
        self,
        input_dim: int,
        expert_archs: List[List[int]],
        num_tasks: int,
        activation: str = "RELU",
    ) -> None:
        super().__init__()
        self.num_layers: int = len(expert_archs)
        self.num_tasks = num_tasks
        self.experts = torch.nn.ModuleList()
        self.stitchs = torch.nn.ModuleList()

        expert_input_dim = input_dim
        for layer_ind in range(self.num_layers):
            self.experts.append(
                torch.nn.ModuleList(
                    [
                        MLP(
                            expert_input_dim,
                            expert_archs[layer_ind],
                            activation,
                        )
                        for _ in range(self.num_tasks)
                    ]
                )
            )

            self.stitchs.append(
                torch.nn.Linear(
                    self.num_tasks,
                    self.num_tasks,
                    bias=False,
                )
            )

            expert_input_dim = expert_archs[layer_ind][-1]

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        input dim [B, T, D_in]
        output dim [B, T, D_out]
        """
        x = input

        for layer_ind in range(self.num_layers):
            expert_out = torch.stack(
                [
                    expert(x[:, expert_ind, :])  # [B, D_out]
                    for expert_ind, expert in enumerate(self.experts[layer_ind])
                ],
                dim=1,
            )  # [B, T, D_out]

            stitch_out = self.stitchs[layer_ind](expert_out.transpose(1, 2)).transpose(
                1, 2
            )  # [B, T, D_out]

            x = stitch_out

        return x


class MLMMoE(torch.nn.Module):
    """
    Multi-level Multi-gate Mixture of Experts
    This code implements a multi-level extension of the MMoE model, as described in the
    paper titled "Modeling Task Relationships in Multi-task Learning with Multi-gate
    Mixture-of-Experts".
    Paper link: https://dl.acm.org/doi/10.1145/3219819.3220007
    To run the original MMoE, use only one fusion level. For example, set expert_archs as
    [[96, 48]].
    To configure multiple fusion levels, set expert_archs as something like [[96], [48]].
    """

    def __init__(
        self,
        input_dim: int,
        expert_archs: List[List[int]],
        gate_archs: List[List[int]],
        num_tasks: int,
        num_experts: int,
        activation: str = "RELU",
    ) -> None:
        super().__init__()
        self.num_layers: int = len(expert_archs)
        self.num_tasks: int = num_tasks
        self.num_experts = num_experts
        self.experts = torch.nn.ModuleList()
        self.gates = torch.nn.ModuleList()

        expert_input_dim = input_dim
        for layer_ind in range(self.num_layers):
            self.experts.append(
                torch.nn.ModuleList(
                    [
                        MLP(
                            expert_input_dim,
                            expert_archs[layer_ind],
                            activation,
                        )
                        for _ in range(self.num_experts)
                    ]
                )
            )
            self.gates.append(
                torch.nn.ModuleList(
                    [
                        torch.nn.Sequential(
                            MLP(
                                input_dim,
                                gate_archs[layer_ind],
                                activation,
                            ),
                            torch.nn.Linear(
                                gate_archs[layer_ind][-1]
                                if gate_archs[layer_ind]
                                else input_dim,
                                self.num_experts,
                            ),
                            torch.nn.Softmax(dim=-1),
                        )
                        for _ in range(
                            self.num_experts
                            if layer_ind < self.num_layers - 1
                            else self.num_tasks
                        )
                    ]
                )
            )
            expert_input_dim = expert_archs[layer_ind][-1]

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        input dim [B, D_in]
        output dim [B, T, D_out]
        """
        x = input.unsqueeze(1).expand([-1, self.num_experts, -1])  # [B, E, D_in]

        for layer_ind in range(self.num_layers):
            expert_out = torch.stack(
                [
                    expert(x[:, expert_ind, :])  # [B, D_out]
                    for expert_ind, expert in enumerate(self.experts[layer_ind])
                ],
                dim=1,
            )  # [B, E, D_out]

            gate_out = torch.stack(
                [
                    gate(input)  # [B, E]
                    for gate_ind, gate in enumerate(self.gates[layer_ind])
                ],
                dim=1,
            )  # [B, T, E]

            gated_out = torch.matmul(gate_out, expert_out)  # [B, T, D_out]

            x = gated_out
        return x


class PLE(nn.Module):
    """
    PLE module is based on the paper "Progressive Layered Extraction (PLE): A
    Novel Multi-Task Learning (MTL) Model for Personalized Recommendations".
    Paper link: https://doi.org/10.1145/3383313.3412236
    PLE aims to address negative transfer and seesaw phenomenon in multi-task
    learning. PLE distinguishes shared and task-specific experts explicitly and
    adopts a progressive routing mechanism to extract and separate deeper
    semantic knowledge gradually. When there is only one extraction layer, PLE
    falls back to CGC.

    Args:
        input_dim: input embedding dimension
        expert_out_dims (List[List[int]]): dimension of an expert's output at
            each layer. This list's length equals the number of extraction
            layers
        num_tasks: number of tasks
        num_task_experts: number of experts for each task module at each layer.
            * If the number of experts is the same for all tasks, use an
            integer here.
            * If the number of experts is different for different tasks, use a
            list of integers here.
        num_shared_experts: number of experts for shared module at each layer

    Call Args:
        inputs: inputs is a tensor of dimension [batch_size, self.num_tasks + 1,
        self.input_dim]. Task specific module inputs are placed first, followed
        by shared module input. (Experts in the same module share the same input)

    Returns:
        output: output of extraction layer to be fed into task-specific tower
            networks. It's a list of tensors, each of which is for one task.

    Example::
        PLE(
            input_dim=256,
            expert_out_dims=[[128]],
            num_tasks=8,
            num_task_experts=2,
            num_shared_experts=2,
        )

    """

    def __init__(
        self,
        input_dim: int,
        expert_out_dims: List[List[int]],
        num_tasks: int,
        num_task_experts: Union[int, List[int]],
        num_shared_experts: int,
        activation: str = "RELU",
    ) -> None:
        super().__init__()
        if len(expert_out_dims) == 0:
            raise ValueError("Expert out dims cannot be empty list")
        self.num_extraction_layers: int = len(expert_out_dims)
        self.num_tasks = num_tasks
        self.num_task_experts = num_task_experts
        if type(num_task_experts) is int:
            self.total_experts_per_layer: int = (
                num_task_experts * num_tasks + num_shared_experts
            )
        else:
            self.total_experts_per_layer: int = (
                sum(num_task_experts) + num_shared_experts
            )
            assert len(num_task_experts) == num_tasks
        self.num_shared_experts = num_shared_experts
        self.experts = nn.ModuleList()
        expert_input_dim = input_dim
        for expert_out_dim in expert_out_dims:
            self.experts.append(
                nn.ModuleList(
                    [
                        MLP(expert_input_dim, expert_out_dim, activation)
                        for i in range(self.total_experts_per_layer)
                    ]
                )
            )
            expert_input_dim = expert_out_dim[-1]

        self.gate_weights = nn.ModuleList()
        selector_dim = input_dim
        for i in range(self.num_extraction_layers):
            expert_out_dim = expert_out_dims[i]
            # task specific gates.
            if type(num_task_experts) is int:
                gate_weights_in_layer = nn.ModuleList(
                    [
                        nn.Sequential(
                            nn.Linear(
                                selector_dim, num_task_experts + num_shared_experts
                            ),
                            nn.Softmax(dim=-1),
                        )
                        for i in range(num_tasks)
                    ]
                )
            else:
                gate_weights_in_layer = nn.ModuleList(
                    [
                        nn.Sequential(
                            nn.Linear(
                                selector_dim, num_task_experts[i] + num_shared_experts
                            ),
                            nn.Softmax(dim=-1),
                        )
                        for i in range(num_tasks)
                    ]
                )
            # Shared module gates. Note last layer has only task specific module gates for task towers later.
            if i != self.num_extraction_layers - 1:
                gate_weights_in_layer.append(
                    nn.Sequential(
                        nn.Linear(selector_dim, self.total_experts_per_layer),
                        nn.Softmax(dim=-1),
                    )
                )
            self.gate_weights.append(gate_weights_in_layer)

            selector_dim = expert_out_dim[-1]

        if type(self.num_task_experts) is list:
            experts_idx_2_task_idx = []
            for i in range(num_tasks):
                # pyre-ignore
                experts_idx_2_task_idx += [i] * self.num_task_experts[i]
            experts_idx_2_task_idx += [num_tasks] * num_shared_experts
            self.experts_idx_2_task_idx: List[int] = experts_idx_2_task_idx

    def forward(
        self,
        inputs: torch.Tensor,
    ) -> torch.Tensor:
        for layer_i in range(self.num_extraction_layers):
            # all task specific and shared experts' outputs.
            # Note first num_task_experts * num_tasks experts are task specific,
            # last num_shared_experts experts are shared.
            if type(self.num_task_experts) is int:
                experts_out = torch.stack(
                    [
                        self.experts[layer_i][expert_i](
                            inputs[
                                :,
                                # pyre-ignore
                                min(expert_i // self.num_task_experts, self.num_tasks),
                                :,
                            ]
                        )
                        for expert_i in range(self.total_experts_per_layer)
                    ],
                    dim=1,
                )  # [B * E (num experts) * D_out]
            else:
                experts_out = torch.stack(
                    [
                        self.experts[layer_i][expert_i](
                            inputs[
                                :,
                                self.experts_idx_2_task_idx[expert_i],
                                :,
                            ]
                        )
                        for expert_i in range(self.total_experts_per_layer)
                    ],
                    dim=1,
                )  # [B * E (num experts) * D_out]

            gates_out = []
            # Loop over all the gates in the layer. Note for the last layer,
            # there is no shared gating network.
            prev_idx = 0
            for gate_i in range(len(self.gate_weights[layer_i])):
                # This is for shared gating network, which uses all the experts.
                if gate_i == self.num_tasks:
                    selected_matrix = experts_out  # S_share
                # This is for task gating network, which only uses shared and its own experts.
                else:
                    if type(self.num_task_experts) is int:
                        task_experts_out = experts_out[
                            :,
                            # pyre-ignore
                            (gate_i * self.num_task_experts) : (gate_i + 1)
                            # pyre-ignore
                            * self.num_task_experts,
                            :,
                        ]  # task specific experts
                    else:
                        # pyre-ignore
                        next_idx = prev_idx + self.num_task_experts[gate_i]
                        task_experts_out = experts_out[
                            :,
                            prev_idx:next_idx,
                            :,
                        ]  # task specific experts
                        prev_idx = next_idx
                    shared_experts_out = experts_out[
                        :,
                        -self.num_shared_experts :,
                        :,
                    ]  # shared experts
                    selected_matrix = torch.concat(
                        [task_experts_out, shared_experts_out], dim=1
                    )  # S_k with dimension of [B * E_selected * D_out]

                gates_out.append(
                    torch.bmm(
                        self.gate_weights[layer_i][gate_i](
                            inputs[:, gate_i, :]
                        ).unsqueeze(dim=1),
                        selected_matrix,
                    )
                    #  W * S -> G
                    #  [B, 1, E_selected] X [B * E_selected * D_out] -> [B, 1, D_out]
                )
            inputs = torch.cat(gates_out, dim=1)  # [B, T, D_out]

        return inputs


class CentralTaskArch(nn.Module):
    def __init__(
        self,
        mtl_configs: MtlConfigs,
        opts: ArchInputs,
        input_dim: int,
    ) -> None:
        super().__init__()
        self.opts = opts

        assert len(mtl_configs.expert_out_dims) > 0, "expert_out_dims is empty."
        self.num_tasks: int = opts.num_task

        self.mtl_model: str = mtl_configs.mtl_model
        logger.info(f"mtl_model is {mtl_configs.mtl_model}")
        expert_out_dims: List[List[int]] = mtl_configs.expert_out_dims
        # AdaTT-sp
        # consider consolidating the implementation of att_sp and att_g.
        if mtl_configs.mtl_model == "att_sp":
            self.mtl_arch: nn.Module = AdaTTSp(
                input_dim=input_dim,
                expert_out_dims=expert_out_dims,
                num_tasks=self.num_tasks,
                num_task_experts=mtl_configs.num_task_experts,
                self_exp_res_connect=mtl_configs.self_exp_res_connect,
                activation=opts.activation_type,
            )
        # AdaTT-general
        elif mtl_configs.mtl_model == "att_g":
            self.mtl_arch: nn.Module = AdaTTWSharedExps(
                input_dim=input_dim,
                expert_out_dims=expert_out_dims,
                num_tasks=self.num_tasks,
                num_task_experts=mtl_configs.num_task_experts,
                num_shared_experts=mtl_configs.num_shared_experts,
                self_exp_res_connect=mtl_configs.self_exp_res_connect,
                activation=opts.activation_type,
            )
        # PLE
        elif mtl_configs.mtl_model == "ple":
            self.mtl_arch: nn.Module = PLE(
                input_dim=input_dim,
                expert_out_dims=expert_out_dims,
                num_tasks=self.num_tasks,
                num_task_experts=mtl_configs.num_task_experts,
                num_shared_experts=mtl_configs.num_shared_experts,
                activation=opts.activation_type,
            )
        # cross-stitch
        elif mtl_configs.mtl_model == "cross_st":
            self.mtl_arch: nn.Module = CrossStitch(
                input_dim=input_dim,
                expert_archs=mtl_configs.expert_out_dims,
                num_tasks=self.num_tasks,
                activation=opts.activation_type,
            )
        # multi-layer MMoE or MMoE
        elif mtl_configs.mtl_model == "mmoe":
            self.mtl_arch: nn.Module = MLMMoE(
                input_dim=input_dim,
                expert_archs=mtl_configs.expert_out_dims,
                gate_archs=[[] for i in range(len(mtl_configs.expert_out_dims))],
                num_tasks=self.num_tasks,
                num_experts=mtl_configs.num_shared_experts,
                activation=opts.activation_type,
            )
        # shared bottom
        elif mtl_configs.mtl_model == "share_bottom":
            self.mtl_arch: nn.Module = SharedBottom(
                input_dim,
                [dim for dims in expert_out_dims for dim in dims],
                self.num_tasks,
                opts.activation_type,
            )
        else:
            raise ValueError("invalid model type")

        task_modules_input_dim = expert_out_dims[-1][-1]
        self.task_modules: nn.ModuleList = nn.ModuleList(
            [
                nn.Sequential(
                    MLP(
                        task_modules_input_dim, self.opts.task_mlp, opts.activation_type
                    ),
                    torch.nn.Linear(self.opts.task_mlp[-1], 1),
                )
                for i in range(self.num_tasks)
            ]
        )

    def forward(
        self,
        task_arch_input: torch.Tensor,
    ) -> List[torch.Tensor]:
        if self.mtl_model in ["att_sp", "cross_st"]:
            task_arch_input = task_arch_input.unsqueeze(1).expand(
                -1, self.num_tasks, -1
            )
        elif self.mtl_model in ["att_g", "ple"]:
            task_arch_input = task_arch_input.unsqueeze(1).expand(
                -1, self.num_tasks + 1, -1
            )

        task_specific_outputs = self.mtl_arch(task_arch_input)

        task_arch_output = [
            task_module(task_specific_outputs[:, i, :])
            for i, task_module in enumerate(self.task_modules)
        ]

        return task_arch_output

Implemented models

The repository implements the following models:

  • AdaTT: the Adaptive Task-to-Task Fusion Network proposed in the paper.
  • MMoE: the Multi-gate Mixture-of-Experts model, a classic multi-task learning model.
  • Multi-level MMoE: an extension of MMoE with multiple fusion levels.
  • PLE: the Progressive Layered Extraction model, another multi-task learning model.
  • Cross-stitch: the cross-stitch network for multi-task learning.
  • Shared-bottom: the shared-bottom network, a basic multi-task baseline (see the configuration sketch below for how each model is selected).
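
Each of these corresponds to a value of MtlConfigs.mtl_model consumed by CentralTaskArch above. A small configuration sketch; the dimensions and expert counts here are illustrative, not the repository's defaults:

# model-name string -> module built by CentralTaskArch:
#   "att_sp"   -> AdaTTSp            "att_g"        -> AdaTTWSharedExps
#   "mmoe"     -> MLMMoE             "ple"          -> PLE
#   "cross_st" -> CrossStitch        "share_bottom" -> SharedBottom
adatt_cfg = MtlConfigs(mtl_model="att_sp", num_task_experts=2,
                       expert_out_dims=[[128, 128], [64, 64]])
ple_cfg = MtlConfigs(mtl_model="ple", num_task_experts=2, num_shared_experts=2,
                     expert_out_dims=[[128, 128], [64, 64]])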

AdaTT model structure

The core of AdaTT is a deep fusion network with task-specific fusion units and an optional shared fusion unit at multiple levels. By using residual connections and gating for cross-task fusion, these units adaptively learn shared knowledge and task-specific knowledge at the same time.

AdaTT code components

Key components of the AdaTT codebase include:

  1. Model definitions: the structures of AdaTT and the other multi-task learning models.
  2. Training and evaluation scripts: scripts for training and evaluating the models.
  3. Data loaders: utilities for loading and processing multi-task learning datasets.
  4. Configuration files: parameters for model training and evaluation.
  5. Experimental results: recorded performance of the different models on various task groups.

Key code fragments of the AdaTT model

Below is an illustrative skeleton of an AdaTT model. Note this is a hedged sketch rather than repository code: the repository's actual modules are the AdaTTSp and AdaTTWSharedExps classes shown above, and this wrapper class and its parameters are only for illustration.

# Illustrative PyTorch skeleton: a thin wrapper around the AdaTTSp module defined above
class AdaTT(nn.Module):
    def __init__(self, input_dim: int, expert_out_dims: List[List[int]],
                 num_tasks: int, num_task_experts: int) -> None:
        super().__init__()
        # fusion backbone: per-task experts with gated fusion and self-expert residuals
        self.backbone = AdaTTSp(input_dim, expert_out_dims, num_tasks, num_task_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, num_tasks, input_dim] -> [batch_size, num_tasks, D_out]
        return self.backbone(x)

Usage

To use the AdaTT library, the steps are roughly as follows (a minimal end-to-end sketch follows the list):

  1. Installation: clone or download the repository and make sure the required dependencies (such as PyTorch) are installed.
  2. Configuration: adjust the parameters in the configuration files as needed.
  3. Data preparation: prepare a multi-task learning dataset and format it as the library expects.
  4. Training: run the training script to start model training.
  5. Evaluation: evaluate the trained model on the test dataset.
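
As a minimal sketch of steps 2-5, using only the classes defined in the listing above and synthetic data; all dimensions, expert counts, and hyperparameters here are illustrative assumptions rather than the repository's defaults:

# Build a 3-task AdaTT-general model and run one training step on random data.
cfg = MtlConfigs(mtl_model="att_g", num_task_experts=2, num_shared_experts=1,
                 expert_out_dims=[[128], [64]], self_exp_res_connect=True)
arch_in = ArchInputs(num_task=3, task_mlp=[32], mtl_configs=cfg, activation_type="RELU")
model = CentralTaskArch(cfg, arch_in, input_dim=256)

features = torch.randn(64, 256)                # [B, D_in] shared input features
labels = torch.randint(0, 2, (64, 3)).float()  # one binary label per task

outputs = model(features)                      # list of T tensors, each of shape [B, 1]
loss = sum(
    nn.functional.binary_cross_entropy_with_logits(out.squeeze(-1), labels[:, t])
    for t, out in enumerate(outputs)
)
loss.backward()                                # an optimizer step would follow in training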

3. Comparing AdaTT, MMoE, and PLE

AdaTT (Adaptive Task-to-Task Fusion Network) is an advanced multi-task learning model that models complex task relationships through an adaptive fusion mechanism and jointly learns task-specific and shared knowledge. Compared with other multi-task learning models such as PLE (Progressive Layered Extraction) and MMoE (Multi-gate Mixture-of-Experts), AdaTT has its own advantages and characteristics.

AdaTT vs. PLE:

  • Structural difference: on top of PLE, AdaTT adds a structurally simple expert network (NativeExpert) whose output is summed with that of AllExpertGF (roughly the shared plus task-specific experts in PLE), strengthening the robustness of task-specific learning after each fusion level.
  • Experimental results: in experiments with low, high, or mixed task relatedness, AdaTT achieves the best performance, with the lowest Normalized Entropy (NE).
  • Online experiments: the paper does not report online results for AdaTT, but in practice adding a linear expert network of this kind usually does not change results much; performance is often on par.

AdaTT vs. MMoE:

  • Model structure: MMoE balances the different tasks through gating networks, whereas AdaTT does so through its adaptive fusion mechanism and residual connections.
  • Task relatedness: MMoE handles task relatedness well, but when the correlations among tasks are diverse, AdaTT's adaptive fusion mechanism captures the relationships between tasks better.

Characteristics of AdaTT:

  • Adaptive fusion mechanism: AdaTT dynamically adjusts information exchange and knowledge sharing between tasks.
  • Multi-level fusion: AdaTT fuses task-specific and shared knowledge at multiple levels, helping the model capture feature representations at different granularities.
  • Residual connections: AdaTT uses residual connections to make deep networks easier to train and to mitigate vanishing gradients.
  • Gating mechanism: AdaTT uses gating to control information flow between tasks, so the model can adaptively select the information most useful for the current task.

Conclusion:

AdaTT shows strong performance and broad applicability in multi-task learning. Compared with models such as PLE and MMoE, its adaptive fusion mechanism and multi-level fusion strategy handle the relationships between tasks effectively and promote both knowledge sharing and task-specific learning. Especially when task relatedness is diverse, AdaTT adapts and learns better, which is its distinctive advantage.

 
