【长时间序列预测】Autoformer 代码详解之[3]模型整体架构分析

1. 模型整体架构图:

 显然,一个编码器-解码器架构。先不考虑batch_sizes,编码器将输入的序列(L,d)编码为(L, d_model),d_model 为设置的参数表示将输入嵌入到多少维度。

编码器输入为:

enc_out:[32, 96, 512] # 假设batch_sizes=32, 输入序列长度=96,输入序列嵌入维度=512

编码器的输出:

enc_out:[32, 96, 512] 。编码器只进行了输入的维度进行了一系列变换,并没有改变长度96。

解码器的输入为:

dec_out:[32, 48 + 192, 512] # lable_len = 48,pred_len = 192。

解码器的输出为:

seasonal_part:[32, 48 + 192, c_out];trend_part:[32, 48 + 192, c_out] # c_out 为数据的实际维度。

最终输出:

outputs:[32, 192, c_out] 

序列分解模块已在前文中进行了详细描述。

图中所谓的 前馈模块是集成在 EncoderLayer 和 DecoderLayer 中的,不过是一些Conv1d -> activation -> dropout -> Conv1d  -> dropout 。

所需注意的只有一点:解码器的第二个自相关模块的输入是来自于编码器的输出作为(k, v),以及前一个解码器的输出作为 q。显然,二者的长度是不一样的。(图中信号的引出表示同一个东西,比如q, k, v来自一个输出就是一个信号复制的3次)

k = v :[32,96,512] 

q: [32, 48 + 192,  512]

所以,AutoCorrelation 类的 forward 函数中:

        B, L, H, E = queries.shape # 32 48+192 4 512/4 # 假设多头=4
        _, S, _, D = values.shape  # 32 96 4 512/4
        if L > S:
            zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
            values = torch.cat([values, zeros], dim=1)
            keys = torch.cat([keys, zeros], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

因此,采用的方案是把 k, v 后面进行了补0到和q 一样长。

因此,我们知道了编码器和解码器并不会对长度进行修改。

2. 模型的输入

2.1 编码器的输入

编码器输入为:enc_out:[32, 96, 512],是如何得到的呢?

答案通过enc_embedding,

输入包括两部分:

        除了时间戳以外的数据 x_enc:[32,  96,  enc_in]  #enc_in 就是输入数据的维度

        时间戳信息 x_mark_enc:[32,  96,  4]  # 如果 args.embed == 'timeF' ,freq='h' 请看前文的数据处理

enc_embedding是DataEmbedding_wo_pos类的对象。该类的forward 函数中将 x_enc(Conv1d) 与 x_mark_enc(Linear) 的shape都编码为 [32,  96,  512] ,然后相加,即得到编码器的输入 enc_out:[32, 96, 512] 

2.1 解码器的输入

对 x_enc:[32,  96,  enc_in]  进行前文的序列分解得到两个序列:

        seasonal_init:[32,  96,  enc_in] 

              trend_init:[32,  96,  enc_in] 

对 x_enc 求均值,和造一个 0 值序列。

        mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1) # mean 会把dim=1 那个维度干掉[32, enc_in] -> [32, 1, enc_in] -> [32, 192, enc_in]
        zeros = torch.zeros([x_enc.shape[0], self.pred_len, x_enc.shape[2]], device=x_enc.device) # 源代码中的 x_dec 完全没必要,替换为 x_enc 就行了。之所以那么做,作者是为了统一接口让别的模型一样传递输入。 [32, 192, enc_in]
        # decoder input
        trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1)
        seasonal_init = torch.cat([seasonal_init[:, -self.label_len:, :], zeros], dim=1)

        seasonal_init:[32,  48 + 192,  enc_in] 

              trend_init:[32,  48 + 192,  enc_in] 

dec_embedding 嵌入 与上文的 enc_embedding 一样。输入:

        seasonal_init:[32,  48 + 192,  enc_in] 

         x_mark_dec:[32,  48 + 192,  4] 

dec_out: [32,  48 + 192,  512] 

3. 解码器的加法怎么做的?

其实论文中的架构图简化了。因此,我重画了解码器部分。

4. 架构总结 

        整体的架构为编码器-解码器结构,但是二者的输入却差距很大。

        编码器的输入是整个输入序列,一层编码器内部包含两个序列分解模块,两个序列分解模块都丢弃了使用移动平均得到的Trend,想法是把除了trend的其他剩余部分分离出来,专注于剩余部分的建模。

        解码器是整个模型的核心,输入包含两部分,序列分解后的trend部分和其他剩余部分。解码器的上半部分试图把剩余部分进一步把趋势和剩余部分分离开来。下半部分则直接把趋势和从剩余部分分离出来的趋势进行求和。上半部分相当于建模了剩余部分。这里的缺点是:原始趋势+解码器分离出来的趋势并没有约束到和原来趋势一样。或许可以革新序列分解模块,并且增加一个loss,让二者的和的目标值为移动平均或者别的算法得到的趋势值,那么才是真正的黑盒趋势分解模型。并且编码器有没有必要也是一个可以考虑的问题。


参考:

GitHub - thuml/Autoformer: About Code release for "Autoformer: Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting" (NeurIPS 2021), https://arxiv.org/abs/2106.13008About Code release for "Autoformer: Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting" (NeurIPS 2021), https://arxiv.org/abs/2106.13008 - GitHub - thuml/Autoformer: About Code release for "Autoformer: Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting" (NeurIPS 2021), https://arxiv.org/abs/2106.13008https://github.com/thuml/Autoformer

计划更新:[4]模型部件之自相关层-AC

  • 5
    点赞
  • 39
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

理心炼丹

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值