【长时间序列预测】Autoformer 代码详解之[4]自相关机制

先看论文中的图:

显然,AutoCorrelation 完全可以替换 self-attention 。输入都是 q,k, v ,输出是一个 V。

直接看代码:

    def forward(self, queries, keys, values, attn_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        if L > S:
            zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
            values = torch.cat([values, zeros], dim=1)
            keys = torch.cat([keys, zeros], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

        # period-based dependencies
        q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1) # 32, 4, 128, 49
        k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1) # 32, 4, 128, 49
        res = q_fft * torch.conj(k_fft) # 32, 4, 128, 49
        corr = torch.fft.irfft(res, dim=-1) # 32, 4, 128, 96

一般情况下,queries, keys, values 的shape都是一样的,比如:[32, 96, 4, 128]。

当然,解码器的第二个AC,keys和 values 的shape 一样,比如:[32, 96, 4, 128]。queries 的为[32, 48+192, 4, 128]。通过上面 if 语句会把 keys和 values 的长度补0 到 queries 一样长。

本文以 queries, keys, values 的shape:[32, 96, 4, 128] 说明。

显然, fft 并没有可说的。主要关注的在 time_delay_agg_training 函数。

            V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
    def time_delay_agg_training(self, values, corr): # 32 4 128 96
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the training phase.
        """
        # 把批次看成1 进行分析,一目了然
        head = values.shape[1] # 4
        channel = values.shape[2] # 128 
        length = values.shape[3] # 96
        # find top k
        top_k = int(self.factor * math.log(length)) # 1* ln(96) = 4
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) # 32, 96 每次求mean 那个维度就会消失
        print(torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)) # 4个值
        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1] # 4 假设得到:[3, 4, 5, 2] # 这里[1]后得到的是index  ## 批次维度求均值得到 shape: 96  再取 top_k
        weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1) # 32 4  把96个里面最大的4个值取出来
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1) # 32 4  这4个值归一化
        # aggregation
        tmp_values = values # 32 4 128 96
        delays_agg = torch.zeros_like(values).float() # 32 4 128 96 都是0
        for i in range(top_k):
            pattern = torch.roll(tmp_values, -int(index[i]), -1) # 32 4 128  96  从序列那个维度移动  # index 对应移动的步长
            delays_agg = delays_agg + pattern * \
                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)) # pattern 乘的是 那个 R_q,k 后面部分表示对应元素相乘 后面的就是那个权重  unsqueeze repeat 是为了形状相同 可以相乘 # tmp_corr[:, i] 这里不考虑batch_size 时是一个数 把这个数repeat 96次
        return delays_agg # 32 4 128 96

上图,假设top_k = 2, 序列长度 = 10。


该模型更新完毕。

更多参考:请参考前文论文阅读部分的参考文献。作者回答了为什么用 torch.roll 。使用 top_k的目的还是降低计算复杂度。

  • 5
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 9
    评论
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

理心炼丹

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值