TESTAM:时间增强的时空注意力混合专家模型

TESTAM:时间增强的时空注意力混合专家模型
From:ICLR 2024

Abstract

最近的工作主要集中在具有自适应图嵌入或图注意的空间建模上,很少关注原位建模的时间特征和有效性。本文提出了时间增强的时空注意力模型,通过3名专家进行时间建模,静态图的时空建模,动态图的时空依赖性建模。

通过引入不同专家并正确规划,TESTAM可以更好地捕获各种情况下的交通模式,包括孤立的道路、高度互联的道路,非重复事件的情况。为了更好地实现规划,我们将门控问题重新表述为带有伪标签的分类任务。

在 METR-LA、PEMS-BAY 和 EXPY-TKY 三个公共交通网络数据集上的实验结果表明,TESTAM 由于能够更好地对重复性和非重复性交通模式进行建模,在准确性方面优于13 种现有方法。

1 Introduction

不同的空间建模方法针对不同的情况具有一定的优势。例如,可学习的静态图建模在可重复的交通流情况下由于动态图;动态空间建模对于非重复性交通流具有优势,例如事故或突然的速度变化。Park 等人(2020)揭示,保留道路信息本身可以提高预测性能,这意味着需要仅时间建模。Jin 等人(2023)表明,基于时间相似性构建的静态图与动态图建模方法相结合可以提高性能。
本文提出了一种时间增强的时空注意力模型(TESTAM),这是一种新颖的专家混合模型(MoE),可以实现现场交通预测。该模型由3位专家模型组成,每位专家的空间建模方式不同,分别为:

  1. 没有空间建模
  2. 可学习的静态图进行空间建模
  3. 动态图建模和一个门控网络

每位专家都由transformer-based blocks用于空间模型方法。
门控网络通过每个专家的最后隐藏状态和输入交通条件,生成用于现场(in-situ)交通预测的候选路径(routes)。通过两个损失函数训练门控网络,来获得最好的路径(route)。

本文的contributions:

  • 提出了一种名为 TESTAM 的新型混合专家模型,用于交通预测,该模型具有多样化的图结构,可以提高不同交通条件下(包括重复和非重复情况)的准确性。
  • 将门控问题重新表述为分类问题,以使模型更好地了解交通状况并在训练期间选择空间建模方法(即专家)。
  • 使用三个真实世界数据集对最先进模型进行的实验结果表明,TESTAM 在数量和质量上都优于现有方法。

2 Related Work

2.1 Traffic forecasting

为了应对这些挑战,我们设计了 TESTAM,使用专家混合技术根据交通环境改变其空间建模方法。

2.2 Mixture of experts

TESTAM 与现有 MoE 有两个主要区别:它利用具有不同空间建模方法的专家来实现更好的泛化,并且可以使用两个损失函数进行优化 - 一个用于避免最差的路径,另一个用于选择最佳路径以实现更好的专业化。

3 Methods

3.1 Preliminaries

问题定义 道路网络: G = ( V , ε , A ) \mathcal{G}=(\mathcal{V},\mathcal{\varepsilon},\mathcal{A}) G=(V,ε,A)

  • V \mathcal{V} V是所有道路的集合, ∣ V ∣ = N |\mathcal{V}|=N V=N
  • ε \varepsilon ε表示道路之间连通性的一组边
  • A ∈ R N × N \mathcal{A}\in\mathbb{R}^{N×N} ARN×N是图的拓扑矩阵(邻接矩阵)

给定道路网络,我们将问题表述为特殊的多元时间序列预测问题,基于 T ′ T^{'} T个历史输入图信号预测未来 T T T个图信号:

[ X G ( t − T ′ + 1 ) , . . . , X G ( t ) ] ⟶ f ( . ) [ X G ( t + 1 ) , . . . , X G ( t + T ) ] [X^{(t-T^{'}+1)}_{\mathcal{G}},..., X^{(t)}_{\mathcal{G}}] \stackrel{f(.)}{\longrightarrow} [X^{(t+1)}_{\mathcal{G}},..., X^{(t+T)}_{\mathcal{G}}] [XG(tT+1),...,XG(t)]f(.)[XG(t+1),...,XG(t+T)]

X G i ∈ R N × C X^{i}_{\mathcal{G}} \in \mathbb{R}^{N×C} XGiRN×C,N是传感器数量,C是输入特征的数量

旨在训练一个映射函数 f ( . ) f(.) f(.)
f ( . ) : R T ′ × N × C ⟶ R T × N × C f(.): \mathbb{R}^{T^{'}×N×C}{\longrightarrow}\mathbb{R}^{T×N×C} f(.):RT×N×CRT×N×C

交通预测中的空间建模方法 在交通预测中,我们可以将空间建模方法分为四类:

  1. 使用单位矩阵(例如,多元时间序列预测)
  2. 使用预定义的邻接矩阵
  3. 使用可训练的邻接矩阵
  4. 使用注意力机制(例如,无需先验知识的动态空间建模)

传统上,图拓扑A是通过经验法则构建的,包括反距离和余弦相似度。然而,这些凭经验构建的图结构不一定是最优的,因此会导致空间建模质量较差。为了应对这一挑战,提出了一系列研究来捕获隐藏的空间信息。

具体来说,使用可训练函数 g ( ⋅ , θ ) g(·, θ) g(⋅,θ) 来推导出最佳拓扑表示 A ~ \widetilde{A} A 为:
A ~ = s o f t m a x ( r e l u ( g ( X ( t ) , θ ) , g ( X ( t ) , θ ) T ) \widetilde{A}=softmax(relu(g(X^{(t)},\theta), g(X^{(t)},\theta)^{T}) A =softmax(relu(g(X(t),θ),g(X(t),θ)T)

上述空间建模方式根据 g ( ⋅ , θ ) g(·, θ) g(⋅,θ) 是否依赖于 X ( t ) X^{(t)} X(t)可分为两个子类别:

  1. 定义 g ( ⋅ , θ ) = E ∈ R N × e g(·, θ)=E\in\mathbb{R^{N×e}} g(⋅,θ)=ERN×e,这与时间无关、对噪声不敏感,但是较少现场建模。(less in-situ modeling)
  2. 时变图结构建模, g ( H ( t ) , θ ) = H ( t ) W , W ∈ R d × e g(H^{(t)},\theta)=H^{(t)}W,W\in\mathbb{R^{d×e}} g(H(t),θ)=H(t)W,WRd×e,将隐藏状态投影到另一个嵌入空间。理想情况下,该方法可以对图拓扑的动态变化进行建模,但它对噪声敏感。

为了降低噪声敏感性并获得时变图结构,Zheng 等人(2020)采用空间注意力机制进行交通流预测。给定节点 i i i 的输入 H i H_i Hi 及其空间邻居 N i N_i Ni,他们使用多头注意力机制计算空间注意力,如下所示:
G A T 公式 H i ∗ = C o n c a t ( o i ( 1 ) , . . . , o i ( K ) ) W o ; o i ( K ) = GAT公式 H^*_i=Concat(o_i^{(1)},...,o_i^{(K)})W^o; o_i^{(K)}= GAT公式Hi=Concat(oi(1),...,oi(K))Wo;oi(K)=

虽然有效,但这些基于注意力的方法仍然受到不规则空间建模的影响,例如,不太准确的注意力,无关空间关系的均匀分布的无信息注意力。

3.2 Model Architecture

尽管 Transformer 是用于时间序列预测的成熟结构,但它在用于时空建模时存在一些问题:它们不考虑空间建模,消耗大量内存资源,并且存在由自回归解码过程引起的瓶颈问题。 Park 等人(2020)引入了一种带有图注意力(GAT)的改进 Transformer 模型,但该模型仍然具有自回归特性。为了消除自回归特征,同时保留编码器-解码器架构的优势,TESTAM 通过时间增强注意力和时间信息嵌入来转移注意力域。

如图1(左)是每个专家层的整体结构,除了时间信息嵌入之外,其还包含四个子层:时间注意力、空间建模、时间增强注意力和逐点前馈神经网络。每个子层通过跳跃连接到旁路。为了提高泛化能力,我们在每个子层之后应用层归一化。
所有专家都具有相同的隐藏大小和层数,仅在空间建模方法方面有所不同。
图1(中)是TESTAM的工作流程和路径机制(routing mechanism)。
图1(右)是三种空间建模方法。黑线表示空间连通性,红线表示空间连通性对应的信息流。

TESTAM模型结构
Temporal Information Embedding

由于时间特征可作为具有特定周期性的全局位置,因此模型省掉了transformer架构中的位置嵌入。此外,不采用归一化时间特征,我们使用Time2Vec嵌入进行周期性和线性建模。
T I M ( τ ) [ i ] = { w i v ( τ ) [ i ] + ϕ i if  i = 0 F ( w i v ( τ ) [ i ] + ϕ i ) if  1 ≤ i ≤ h − 1 TIM(\tau)[i] = \begin{cases} w_iv(\tau)[i]+\phi_i &\text{if } i=0 \\ \mathcal{F}(w_iv(\tau)[i]+\phi_i) &\text{if } 1\leq i \leq h-1 \end{cases} TIM(τ)[i]={wiv(τ)[i]+ϕiF(wiv(τ)[i]+ϕi)if i=0if 1ih1

Temporal Attention

TESTAM中的时间注意力与transformer中的相同。注意力是时间建模的一个有吸引力的解决方案,因为与基于循环单元或基于卷积的时间建模不同,它可用于直接关注跨时间步长的特征,没有任何限制。时间注意力允许并行计算,有利于长期序列建模。此外,它在局部性和顺序性方面具有较少的归纳偏差。虽然强归纳偏差可以帮助训练,但较少的归纳偏差可以实现更好的泛化。
此外,对于交通预测问题,道路之间的因果关系是一个不可避免的因素(Jin et al, 2023),在存在强归纳偏差(例如顺序性或局部性)的情况下无法轻松建模。

Spatial Modeling Layer

三种空间建模层提供给专家,如图1(中)所示:

  • 使用单位矩阵的空间建模(即无空间建模)
  • 使用可学习邻接矩阵的空间建模(方程 1)
  • 以及带有注意力的空间建模(等式 2 和等式 3)

受到记忆增强图结构学习成功的启发,我们提出了一种改进的元图学习器,可以从空间图建模和门控网络中学习原型。

Time-Enhanced Attention

3.3 Gating Networks

基于记忆的门控网络和带有基于回归误差的伪标签的两种分类损失。

4 Experiments

三个基准数据集METR-LA、PEMS-BAY 和 EXPY-TKY。METR-LA 和 PEMS-BAY 包含洛杉矶高速公路上的 207 个传感器和湾区 325 个传感器记录的四个月速度数据(Li et al, 2018)。EXPY-TKY 包含从日本东京的 1843 个链路收集的三个月速度数据。

5 Conclusion

在本文中,我们提出了时间增强时空注意力模型(TESTAM),这是一种具有注意力的新型专家混合模型,可以在重复和非重复情况下进行有效的原位空间建模。通过将路径问题转化为分类任务,TESTAM 可以结合各种交通状况并选择最合适的空间建模方法。 TESTAM 在三个现实数据集中实现了优于现有流量预测模型的性能:METR-LA、PEMS-BAY 和 EXPY-TKY。使用 EXPY-TKY 数据集获得的结果表明,TESTAM 对于大规模图结构非常有优势,更适用于现实世界的问题。我们还获得了定性结果,以可视化 TESTAM 选择特定图形结构的时间和地点。在未来的工作中,我们计划进一步改进和推广 TESTAM 用于其他时空和多元时间序列预测任务。

代码实现

数据集介绍 - metr-la 及generate_training_data.py

读取metr-la.h5

file_name = "data/metr-la.h5"
df = pd.read_hdf(file_name)
df

数据集预览

数据集的形状为34272×207

  • timestamp:时间戳,表示观测的时间点,共34272个时刻。
  • flow(或类似指标):表示在该时间点上的交通流量,在该数据集中每个时刻可以捕获到207个节点的交通流量数据。

通过运行下述python命令,对数据集进行处理

python generate_training_data.py --output_dir=data/METR-LA --traffic_df_fiilename=data/metr-la.h5 --seq_length_x INPUT_SEQ_LENGTH --seq_length_y PRED_SEQ_LENGTH

详细的处理过程可参考代码或下述帖子:数据集转换

具体操作有:

  1. 添加当前时刻在某一天中的offset
    因为每5分钟采集一次数据,所以一天共采集24×60÷5=288个数据,offset为0-288,归一化到0-1之间。
time_ind = (df.index.values - df.index.values.astype("datetime64[D]")) / np.timedelta64(1, "D")
time_ind.shape,time_ind

输出:
((34272,),
 array([0.        , 0.00347222, 0.00694444, ..., 0.98958333, 0.99305556,
        0.99652778]))
# 查看第一天各时刻的offset
time_ind[:288]

一天offset表示

  1. 将各时刻在一天内的offset值与df的流量数据进行拼接合并
time_in_day = np.tile(time_ind, [1, 207, 1]).transpose((2, 1, 0))
'''
np.tile(time_ind, [1, num_nodes, 1]),
time_ind 是(34272,)的一维向量,遇到tile的时候首先扩展维度
扩展维度一般是shape向左扩展,也即变成(1,1,34272)
然后用tile扩展维度,变成(1,num_nodes,34272)
在经过transpose,第2个维度和第0个维度互换
'''
time_in_day.shape
#(34272, 207, 1)

data = np.expand_dims(df.values, axis=-1)
data.shape
#(34272, 207, 1)
 
data_list = [data]
data_list.append(time_in_day)
data = np.concatenate(data_list, axis=-1)
data.shape
#(34272, 207, 2)

原来的data (34272, 207, 1)
处理后的data (34272, 207, 2)

  1. 生成输入和ground-truth列表
# x_offset: [-11, -10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0]
# y_offset: [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12]
x, y = [], []
min_t = abs(min(x_offsets))
max_t = abs(num_samples - abs(max(y_offsets)))  # Exclusive
min_t,max_t
#(11, 34260)

for t in range(min_t, max_t): # 从11到34260,共34249个数据
    x_t = data[t + x_offsets, :]
    y_t = data[t + y_offsets, :]
    x.append(x_t)
    y.append(y_t)
x = np.stack(x, axis=0)
y = np.stack(y, axis=0)
 
x.shape,y.shape
#((34249, 12, 207, 2), (34249, 12, 207, 2))
 
'''
x_t y_t是
[ 0  1  2  3  4  5  6  7  8  9 10 11]
[22 21 20 19 18 17 16 15 14 13 12 11]
**********
[ 1  2  3  4  5  6  7  8  9 10 11 12]
[23 22 21 20 19 18 17 16 15 14 13 12]
**********
[ 2  3  4  5  6  7  8  9 10 11 12 13]
[24 23 22 21 20 19 18 17 16 15 14 13]
**********
[ 3  4  5  6  7  8  9 10 11 12 13 14]
[25 24 23 22 21 20 19 18 17 16 15 14]
**********
[ 4  5  6  7  8  9 10 11 12 13 14 15]
[26 25 24 23 22 21 20 19 18 17 16 15]
**********
[ 5  6  7  8  9 10 11 12 13 14 15 16]
[27 26 25 24 23 22 21 20 19 18 17 16]
**********
一位一位向前滚
即 滑动窗口法处理数据
x为前12h的数据,y为后12h的数据
'''
  1. 按照7:1:2划分train,val,test集
num_samples = x.shape[0]
num_test = round(num_samples * 0.2)
num_train = round(num_samples * 0.7)
num_val = num_samples - num_test - num_train
num_test,num_train,num_val
#(6850, 23974, 3425)
  1. 将训练集、验证集、测试集以npz形式保存至本地
for cat in ["train", "val", "test"]:
        _x, _y = locals()["x_" + cat], locals()["y_" + cat]
        '''
        使用locals()函数动态获取名为x_train, y_train, x_val, y_val, x_test, y_test的变量
        这些变量分别代表训练集、验证集和测试集的输入和输出数据
        '''
        print(cat, "x: ", _x.shape, "y:", _y.shape)
 
        np.savez_compressed(
            os.path.join(args.output_dir, "%s.npz" % cat),
            x=_x,
            y=_y,
            x_offsets=x_offsets.reshape(list(x_offsets.shape) + [1]),
            y_offsets=y_offsets.reshape(list(y_offsets.shape) + [1]),
        )
        '''
        使用numpy.savez_compressed函数将数据保存到压缩文件中,文件名格式为{分类}.npz
        输入数据保存为关键字x。
        输出数据保存为关键字y。
        输入和输出的时间偏移量(x_offsets和y_offsets)也被保存    
        '''

train.py

  1. dataloader是一个字典结构,形式如下:
# 代码调用
dataloader = util.load_dataset(args.data, args.batch_size, args.batch_size, args.batch_size)
scaler = dataloader['scaler'] # scaler包含了数据处理相关的信息,通常是对数据进行归一化的计算参数和方法。

下面是dataloader的结构,是一个字典,包含x_train,x_val,…,train_dataloader等键,说明它们之间的关系,尤其是x_train与train_dataloader的关系,为什么会是这样的结构

{'x_train': array([[[[ 0.51139864,  0.        ],
         [ 0.67811883,  0.        ],
         [ 0.65246957,  0.        ],
         ...,
         [-0.49960972,  0.27777778],
         [ 0.62325791,  0.27777778],
         [-0.09492139,  0.27777778]]]]), 
'y_train': array([[[[6.11250000e+01, 4.16666667e-02],
         [6.70000000e+01, 4.16666667e-02],
         [5.85000000e+01, 4.16666667e-02],
         ...,
         [1.90000000e+01, 3.19444444e-01],
         [6.62500000e+01, 3.19444444e-01],
         [6.48750000e+01, 3.19444444e-01]]]]), 
'x_val': array([[[[ 0.41154973,  0.24305556],
         [ 0.61674382,  0.24305556],
         [ 0.69735578,  0.24305556],
         ..., 
         [ 0.42803854,  0.17013889],
         [ 0.56269716,  0.17013889],
         [ 0.44727549,  0.17013889]]]]),
'y_val': array([[[[67.875     ,  0.28472222],
         [65.75      ,  0.28472222],
         [62.875     ,  0.28472222],
         ...,
         [64.        ,  0.21180556],
         [66.33333333,  0.21180556],
         [65.        ,  0.21180556]]]]), 
'x_test': array([[[[0.48076202, 0.13541667],
         [0.58335906, 0.13541667],
         [0.75435413, 0.13541667],
         ..., 
         [0.63323263, 0.95486111],
         [0.620408  , 0.95486111],
         [0.46010012, 0.95486111]]]]), 
'y_test': array([[[[65.25      ,  0.17708333],
         [60.375     ,  0.17708333],
         [62.125     ,  0.17708333],
         ...,       
         [63.55555556,  0.99652778],
         [68.66666667,  0.99652778],
         [61.77777778,  0.99652778]]]]), 
'train_loader': <util.DataLoader object at 0x7ff7a323abe0>, 
'val_loader': <util.DataLoader object at 0x7ff7a323af70>, 
'test_loader': <util.DataLoader object at 0x7ff7a323aca0>, 
'scaler': <util.StandardScaler object at 0x7ff7a329e5e0>}
  • dataloader中的 ‘x_train’ 和 ‘train_loader’ 是训练集数据的两种不同表达方式。‘x_train’ 是原始的数组表示,而 ‘train_loader’ 是通过数据加载器封装后的对象,用于方便地批量加载训练数据。‘train_loader’ 可以通过调用其提供的方法来访问和迭代训练数据,而 ‘x_train’ 则是直接访问训练数据的原始数组。
  • ‘scaler’:一个标准化器(StandardScaler对象)。这个对象用于对数据进行归一化操作,以确保数据具有统一的尺度和分布。在训练过程中,通常使用训练集的数据来计算归一化所需的参数,并将这些参数应用到训练、验证和测试数据上。这样可以确保在不同数据集上使用相同的归一化方式。

train_loader的batch_size是64,num_batch是375,size是24000,
x_train的shape是(23974,12,207,2)size是119102832,
这些数字之间有什么关系
x_train 的 size 可以计算为 23974 * 12 * 207 * 2 = 119,102,832。
train_loader的size可以计算为 375 * 64 = 24000
num_batches = num_samples / batch_size = 23974 / 64 ≈ 375

  1. 训练过程
for iter, (x, y) in enumerate(dataloader['train_loader'].get_iterator()):
            trainx = torch.Tensor(x).to(device) # torch.Size([64, 12, 207, 2])
            trainx= trainx.transpose(1, 3) # torch.Size([64, 2, 207, 12])
            trainy = torch.Tensor(y).to(device) # torch.Size([64, 12, 207, 2])
            trainy = trainy.transpose(1, 3) # torch.Size([64, 2, 207, 12])
            metrics = engine.train(trainx, trainy[:,0,:,:], i) # torch.Size([64, 207, 12])
            train_loss.append(metrics[0])
            train_mape.append(metrics[1])
            train_rmse.append(metrics[2])
            if iter % args.print_every == 0 :
                log = 'Iter: {:03d}, Train Loss: {:.4f}, Train MAPE: {:.4f}, Train RMSE: {:.4f}'
                print(log.format(iter, train_loss[-1], train_mape[-1], train_rmse[-1]),flush=True)

训练时每个batch有64个数据,每个数据用12个历史值预测未来12个值,
输入模型的x维度为[64,2,207,12],y为[64,207,12],只需要流量数据。

engine = trainer(scaler, args.in_dim, args.seq_length, args.num_nodes, args.nhid, args.dropout,
                         device, args.lr_mul, args.n_warmup_steps, args.quantile, args.is_quantile, args.warmup_epoch)

engine.py

该py文件主要定义了trainer类,该文件结构如下:

engine.py
│
└──  class trainer() 
      ├── def __init__()
      ├── def train()
	  ├── def eval()
	  ├── def get_quantile_label()
      └── def get_label()

def init() 初始化模型、学习率、优化算法、损失函数等
def train() 将module设置为 training mode,计算损失等,返回loss,mape,rmse
def eval() 将模型设置成evaluation模式
PyTorch中model.train()和model.eval()细节分析
def get_quantile_label
def get_label()

model.py

class TESTAM()

self.identity_expert = TemporalModel(hidden_size, num_nodes, in_dim = in_dim - 1, layers = layers, dropout = dropout)
self.adaptive_expert = STModel(hidden_size, self.supports_len, num_nodes, in_dim = in_dim, layers = layers, dropout = dropout)
self.attention_expert = AttentionModel(hidden_size, in_dim = in_dim, layers = layers, dropout = dropout)

self.gate_network = MemoryGate(hidden_size, num_nodes)

定义了三个专家类和一个门控选择类

在具体的类构造方法中:
STModel含有gcn层,QKVAttention层,PositionwiseFeedForward层,
先通过一个线性层,start_linear,从输入维度in_dim,输出维度hidden_size
然后通过layers层带有norm的SkipConnection的QKVAttention和gcn,输入输出维度都是hidden_size
之后通过layers层带有norm的SkipConnection的ffn,输入维度为hidden_size,输出维度为4*hidden_size
最后通过一个ReLU激活函数和线性层,输入维度为hidden_size,输出维度为hidden_size + out_dim

AttentionModel含有QKVAttention层,PositionwiseFeedForward层,没有gcn层

TemporalModel通过TemporalInformationEmbedding对象将输入的时间信息转换为嵌入向量。然后,我们使用torch.einsum函数将嵌入向量和节点特征进行组合。接下来,我们处理速度信息,并将其与嵌入向量进行拼接。然后,我们通过QKVAttention和PositionwiseFeedForward层进行一系列的变换。最后,我们通过一个线性层将输出转换为预测值。

TESTAM
这段代码定义了一个名为TESTAM的模型,它是PyTorch的nn.Module的子类。TESTAM模型包含三个专家模型:identity_expert,adaptive_expert和attention_expert,以及一个gate_network。

在__init__方法中,首先初始化了一些基本参数,如dropout,prob_mul和supports_len。然后,创建了三个专家模型和一个门网络。对于每个专家模型,都会对其参数进行Xavier均匀初始化,这是一种常用的权重初始化方法,可以在训练深度神经网络时帮助我们提高模型的性能。

在forward方法中,首先计算了gate_network的两个权重矩阵We1和We2与其记忆memory的乘积,然后通过ReLU激活函数和softmax函数得到新的支持矩阵new_supports。接着,从输入中提取出时间索引,并计算出当前时间索引和下一个时间索引。然后,通过三个专家模型分别处理输入,得到三个输出和隐藏状态。这三个输出被拼接在一起,形成ind_out。

接下来,通过gate_network计算出门的值,然后创建一个与o_identity形状相同的全零张量out。然后,对于每个专家模型的输出,找出门值最大的部分,将其对应的输出赋值给out。如果prob_mul为真,那么out会与门的最大值相乘。

最后,将out的形状调整为(B,N,T,1),并将其维度进行重新排列。如果模型处于训练状态或者gate_out为真,那么返回out,gate和ind_out,否则只返回out。

总的来说,这个模型通过三个专家模型处理输入,然后通过一个门网络决定如何将这三个专家模型的输出结合在一起。这是一种常见的集成学习策略,可以提高模型的性能。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值