【Autoformer Source Code Notes】Call Hierarchy of the Attention Modules

  🟢  Encoder structure  →  call to Encoder.forward

enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)

Encoder init

        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AutoCorrelationLayer(
                        AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout,
                                        output_attention=configs.output_attention),
                        configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    moving_avg=configs.moving_avg,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for l in range(configs.e_layers)
            ],
            norm_layer=my_Layernorm(configs.d_model)
        )
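
For reference, a minimal sketch of how this constructor can be instantiated on its own. The config values are hypothetical (in the repository they come from the experiment config), and the import paths assume the official Autoformer repository layout (layers/AutoCorrelation.py and layers/Autoformer_EncDec.py):

from types import SimpleNamespace

import torch
from layers.AutoCorrelation import AutoCorrelation, AutoCorrelationLayer  # assumed paths
from layers.Autoformer_EncDec import Encoder, EncoderLayer, my_Layernorm

# Hypothetical config values, for illustration only
configs = SimpleNamespace(
    factor=1,                # top_k = int(factor * log(L)) inside AutoCorrelation
    dropout=0.1,
    output_attention=False,  # whether AutoCorrelation returns its correlation map
    d_model=512,             # model (embedding) dimension
    n_heads=8,
    d_ff=2048,               # hidden dimension of the conv feed-forward
    moving_avg=25,           # kernel size of the series_decomp moving average
    activation='gelu',
    e_layers=2,              # number of stacked EncoderLayers
)

encoder = Encoder(
    [
        EncoderLayer(
            AutoCorrelationLayer(
                AutoCorrelation(False, configs.factor,
                                attention_dropout=configs.dropout,
                                output_attention=configs.output_attention),
                configs.d_model, configs.n_heads),
            configs.d_model,
            configs.d_ff,
            moving_avg=configs.moving_avg,
            dropout=configs.dropout,
            activation=configs.activation,
        ) for _ in range(configs.e_layers)
    ],
    norm_layer=my_Layernorm(configs.d_model),
)

x = torch.randn(4, 96, configs.d_model)         # [B, L, d_model]
enc_out, attns = encoder(x, attn_mask=None)     # [B, L, d_model], list with e_layers entries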

1️⃣ Jump to: Encoder.forward → calls each EncoderLayer

Encoder.forward: the loop over the EncoderLayers

            for attn_layer in self.attn_layers:
                # Pass the input through one encoder layer: self-attention, residual connection,
                # series decomposition and feed-forward network
                # [B, L, d_model] -> EncoderLayer (attention + residual + decomposition + FFN) -> [B, L, d_model],
                # attn is the autocorrelation map [B, L, H, E] (or None when output_attention=False)
                x, attn = attn_layer(x, attn_mask=attn_mask) 
                attns.append(attn)

Full code of the Encoder class:

class Encoder(nn.Module):
    """
    Autoformer encoder
    """
    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        # x: [B, L, d_model] - input sequence features; attn_mask: [B, L, L] or None - attention mask
        attns = []

        if self.conv_layers is not None: 
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                # [B, L, d_model] -> attention layer -> [B, L, d_model], attn: [B, L, H, E] or None
                x, attn = attn_layer(x, attn_mask=attn_mask) 
                # [B, L, d_model] -> conv layer (not used by this model: conv_layers is None here)
                x = conv_layer(x) 
                attns.append(attn)
            # [B, L, d_model] -> last attention layer -> [B, L, d_model], attn: [B, L, H, E] or None
            x, attn = self.attn_layers[-1](x) 
            attns.append(attn)
        else:
            for attn_layer in self.attn_layers:
                # Pass the input through one encoder layer: self-attention, residual connection,
                # series decomposition and feed-forward network
                # [B, L, d_model] -> EncoderLayer (attention + residual + decomposition + FFN) -> [B, L, d_model],
                # attn: [B, L, H, E] (or None when output_attention=False)
                x, attn = attn_layer(x, attn_mask=attn_mask) 
                attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns
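
The norm_layer used above is my_Layernorm, which is not shown in this post. A minimal sketch of what it does, paraphrasing the standard Autoformer implementation (a plain LayerNorm whose mean over the time dimension is subtracted again, so the seasonal output stays zero-centered):

import torch
import torch.nn as nn

class my_Layernorm(nn.Module):
    """LayerNorm variant designed for the seasonal part."""
    def __init__(self, channels):
        super().__init__()
        self.layernorm = nn.LayerNorm(channels)

    def forward(self, x):                       # x: [B, L, d_model]
        x_hat = self.layernorm(x)
        # subtract the mean over the time dimension from every time step
        bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1)
        return x_hat - bias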

2️⃣ Jump to EncoderLayer.forward

class EncoderLayer(nn.Module):
    """
    Autoformer encoder layer with the progressive decomposition architecture
    """
    def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):

        # Apply the (auto-correlation) self-attention to the input sequence x to model
        # dependencies along the time axis
        # x: [B, L, d_model] -> new_x: [B, L, d_model], attn: [B, L, H, E] or None
        new_x, attn = self.attention(
            x, x, x,
            attn_mask=attn_mask
        )

Full code:

class EncoderLayer(nn.Module):
    """
    Autoformer encoder layer with the progressive decomposition architecture
    """
    def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):

        # Apply the (auto-correlation) self-attention to the input sequence x to model
        # dependencies along the time axis
        # x: [B, L, d_model] -> new_x: [B, L, d_model], attn: [B, L, H, E] or None
        new_x, attn = self.attention(
            x, x, x,
            attn_mask=attn_mask
        )

        # Residual connection with dropout for regularization
        # x [B, L, d_model] + dropout(new_x) [B, L, d_model] -> x [B, L, d_model]
        x = x + self.dropout(new_x)

        # Series decomposition: keep the seasonal part, discard the trend part (_)
        # x [B, L, d_model] -> decomposition -> seasonal x [B, L, d_model], unused trend _ [B, L, d_model]
        x, _ = self.decomp1(x)

        # Keep a copy of x as the input to the feed-forward branch
        # x [B, L, d_model] -> y [B, L, d_model]
        y = x

        # First 1x1 conv expands the feature dimension, followed by activation and dropout
        # y [B, L, d_model] -> transpose [B, d_model, L] -> conv [B, d_ff, L] -> activation -> dropout -> y [B, d_ff, L]
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        
        # Second 1x1 conv maps back to the model dimension, then dropout and transpose back
        # y [B, d_ff, L] -> conv [B, d_model, L] -> dropout -> transpose -> y [B, L, d_model]
        y = self.dropout(self.conv2(y).transpose(-1, 1))

        # Add the feed-forward output back to x and decompose again, keeping the seasonal part
        # (x + y) [B, L, d_model] -> decomposition -> res [B, L, d_model], unused trend _ [B, L, d_model]
        res, _ = self.decomp2(x + y)

        # Return the processed features and the attention (autocorrelation) map
        # res [B, L, d_model], attn [B, L, H, E] or None
        return res, attn
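
EncoderLayer also relies on series_decomp, which is not shown in this post. A minimal sketch of what it does, paraphrasing the standard Autoformer implementation: the trend is a moving average (with replication padding at both ends so the length is preserved), and the seasonal part is the residual:

import torch
import torch.nn as nn

class moving_avg(nn.Module):
    """Moving average that highlights the trend of the series."""
    def __init__(self, kernel_size, stride=1):
        super().__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):                                   # x: [B, L, d_model]
        # pad both ends by repeating the boundary values so the output keeps length L
        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        x = torch.cat([front, x, end], dim=1)
        return self.avg(x.permute(0, 2, 1)).permute(0, 2, 1)

class series_decomp(nn.Module):
    """Series decomposition: x = seasonal + trend."""
    def __init__(self, kernel_size):
        super().__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        moving_mean = self.moving_avg(x)                    # trend    [B, L, d_model]
        res = x - moving_mean                               # seasonal [B, L, d_model]
        return res, moving_mean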

3️⃣ Jump to AutoCorrelationLayer.forward

class AutoCorrelationLayer(nn.Module):
    def __init__(self, correlation, d_model, n_heads, d_keys=None,
                 d_values=None):
        super(AutoCorrelationLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_correlation = correlation
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        # Batch size and query sequence length
        # queries [B, L, d_model] -> B (batch size), L (query length)
        B, L, _ = queries.shape

        # Key sequence length
        # keys [B, S, d_model] -> S (key length)
        _, S, _ = keys.shape

        # Number of attention heads
        H = self.n_heads

        # Project the queries into the multi-head space and reshape
        # [B, L, d_model] -> linear -> [B, L, H*d_k] -> reshape -> [B, L, H, d_k]
        queries = self.query_projection(queries).view(B, L, H, -1)

        # Project the keys into the multi-head space and reshape
        # [B, S, d_model] -> linear -> [B, S, H*d_k] -> reshape -> [B, S, H, d_k]
        keys = self.key_projection(keys).view(B, S, H, -1)

        # Project the values into the multi-head space and reshape
        # [B, S, d_model] -> linear -> [B, S, H*d_v] -> reshape -> [B, S, H, d_v]
        values = self.value_projection(values).view(B, S, H, -1)

        # Compute the auto-correlation "attention" and aggregate period-based dependencies
        # [B, L, H, d_k], [B, S, H, d_k], [B, S, H, d_v] -> out [B, L, H, d_v], attn [B, L, H, d_k] or None
        out, attn = self.inner_correlation(
            queries,
            keys,
            values,
            attn_mask
        )

4️⃣ Jump to AutoCorrelation.forward

class AutoCorrelation(nn.Module):
    """
    AutoCorrelation Mechanism with the following two phases:
    (1) period-based dependencies discovery
    (2) time delay aggregation
    This block can replace the self-attention family mechanism seamlessly.
    """
    def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False):
        super(AutoCorrelation, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def time_delay_agg_training(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the training phase.
        """

        # Number of attention heads
        # values [B, H, D, L] -> head (H)
        head = values.shape[1]

        # Number of feature channels
        # values [B, H, D, L] -> channel (D)
        channel = values.shape[2]

        # Sequence length, used for the top-k selection and the circular shifts
        # values [B, H, D, L] -> length (L)
        length = values.shape[3]

        # find top k
        # Number of top correlated lags to keep, derived from the sequence length
        # length (scalar) -> factor * log(length) -> top_k (scalar)
        top_k = int(self.factor * math.log(length))
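        # e.g. with factor=1 and length=96 (a typical seq_len), top_k = int(1 * math.log(96)) = 4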

        # Average the correlation over the head and channel dimensions
        # corr [B, H, E, L] -> mean over heads and channels -> mean_value [B, L]
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)

        # Indices of the top_k lags with the highest batch-averaged correlation
        # mean_value [B, L] -> mean over batch [L] -> top_k largest -> index [top_k]
        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]

        # Gather the correlation value of each selected lag and stack them into a weight matrix
        # mean_value [B, L], index [top_k] -> gather and stack -> weights [B, top_k]
        weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)
        
        # update corr
        # Softmax-normalize the weights into a probability distribution
        # weights [B, top_k] -> softmax -> tmp_corr [B, top_k]
        tmp_corr = torch.softmax(weights, dim=-1)

        # aggregation
        # Reference to values used for the circular-shift operations
        # values [B, H, D, L] -> tmp_values [B, H, D, L]
        tmp_values = values

        # Initialize the delay-aggregation result with zeros
        # zeros shaped like values -> delays_agg [B, H, D, L]
        delays_agg = torch.zeros_like(values).float()

        # Aggregate over each of the selected top-k lags
        # loop runs top_k times
        for i in range(top_k):
            # Circularly shift values by the current lag, emulating the time delay
            # tmp_values [B, H, D, L] -> torch.roll -> pattern [B, H, D, L]
            pattern = torch.roll(tmp_values, -int(index[i]), -1)

            # Weight the shifted pattern by its correlation score and accumulate
            # tmp_corr[:, i] [B] -> expand to [B, 1, 1, 1] -> repeat to [B, H, D, L] * pattern -> add to delays_agg
            delays_agg = delays_agg + pattern * \
                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
        # Return the final time-delay aggregation result delays_agg [B, H, D, L]
        return delays_agg

    def time_delay_agg_inference(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the inference phase.
        """
        batch = values.shape[0]
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # index init
        init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0)\
            .repeat(batch, head, channel, 1).to(values.device)
        # find top k
        top_k = int(self.factor * math.log(length))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        weights, delay = torch.topk(mean_value, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * \
                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
        return delays_agg

    def time_delay_agg_full(self, values, corr):
        """
        Standard version of Autocorrelation
        """
        batch = values.shape[0]
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # index init
        init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0)\
            .repeat(batch, head, channel, 1).to(values.device)
        # find top k
        top_k = int(self.factor * math.log(length))
        weights, delay = torch.topk(corr, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[..., i].unsqueeze(-1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
        return delays_agg

    def forward(self, queries, keys, values, attn_mask):
        # Shapes of the query tensor
        # queries [B, L, H, E] -> B (batch), L (query length), H (heads), E (per-head dim)
        B, L, H, E = queries.shape

        # Shapes of the value tensor, in particular the key/value length S and the per-head value dim D
        # values [B, S, H, D] -> S (key/value length), D (per-head value dim)
        _, S, _, D = values.shape

        # Align the key/value length with the query length when they differ
        if L > S:

            # Zero padding so that keys and values match the query length
            # queries[:, :(L-S), :] -> zeros of the same shape -> zeros [B, L-S, H, E]
            zeros = torch.zeros_like(queries[:, :(L - S), :]).float()

            # Append the zeros to the values along the time dimension
            # [values [B, S, H, D], zeros [B, L-S, H, E]] -> concat on dim 1 -> values [B, L, H, D]
            values = torch.cat([values, zeros], dim=1)
            
            # Append the zeros to the keys along the time dimension
            # [keys [B, S, H, E], zeros [B, L-S, H, E]] -> concat on dim 1 -> keys [B, L, H, E]
            keys = torch.cat([keys, zeros], dim=1)
        else:

            # Truncate the values to the query length
            # values [B, S, H, D] -> slice -> values [B, L, H, D]
            values = values[:, :L, :, :]

            # Truncate the keys to the query length
            # keys [B, S, H, E] -> slice -> keys [B, L, H, E]
            keys = keys[:, :L, :, :]

        # period-based dependencies
        # Frequency-domain representation of the queries, used to discover period-based dependencies
        # [B, L, H, E] -> permute [B, H, E, L] -> real FFT -> q_fft [B, H, E, L//2+1] (complex)
        q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
        # Frequency-domain representation of the keys, correlated with the queries in the frequency domain
        # [B, L, H, E] -> permute [B, H, E, L] -> real FFT -> k_fft [B, H, E, L//2+1] (complex)
        k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
        # Cross-correlation in the frequency domain via element-wise complex multiplication
        # q_fft [B, H, E, L//2+1] * conj(k_fft) [B, H, E, L//2+1] -> res [B, H, E, L//2+1]
        res = q_fft * torch.conj(k_fft)
        # Inverse FFT back to the time domain, giving the correlation for every lag
        # res [B, H, E, L//2+1] -> inverse real FFT -> corr [B, H, E, L]
        corr = torch.fft.irfft(res, n=L, dim=-1)

        # time delay agg
        # Choose the time-delay aggregation routine depending on training vs. inference mode
        if self.training:
            # Training mode: batch-shared top-k lags, aggregation via torch.roll
            # values [B, L, H, D] -> permute [B, H, D, L] -> aggregate -> permute back -> V [B, L, H, D]
            V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
        else:
            # Inference mode: per-sample top-k lags, aggregation via gather on a doubled sequence
            # values [B, L, H, D] -> permute [B, H, D, L] -> aggregate -> permute back -> V [B, L, H, D]
            V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)

        # Return the attention (correlation) map only when output_attention is enabled
        if self.output_attention:
            # Return the aggregated values and the correlation map
            # V [B, L, H, D], corr [B, H, E, L] -> (V [B, L, H, D], corr permuted to [B, L, H, E])
            return (V.contiguous(), corr.permute(0, 3, 1, 2))
        else:
            # Return only the aggregated values, without the attention map
            # V [B, L, H, D] -> (V [B, L, H, D], None)
            return (V.contiguous(), None)
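
Both phases of the mechanism can be checked in isolation. A small standalone sketch (not from the repository): part (1) shows that multiplying the FFT of a signal by its conjugate and inverting the FFT yields the correlation at every lag, peaking at multiples of the signal's period; part (2) shows that the roll-based shift used in training and the gather-on-a-doubled-sequence trick used at inference produce the same circular delay:

import math
import torch

# --- (1) period-based dependencies: FFT autocorrelation of a period-24 signal ---
L = 96
t = torch.arange(L, dtype=torch.float32)
x = torch.sin(2 * math.pi * t / 24)                        # toy signal with period 24

x_fft = torch.fft.rfft(x, dim=-1)
corr = torch.fft.irfft(x_fft * torch.conj(x_fft), n=L, dim=-1)
print(torch.topk(corr, 4).indices)                         # the 4 strongest lags are the multiples
                                                           # of the period: 0, 24, 48, 72 (in some order)

# --- (2) time delay aggregation: torch.roll (training) vs. gather on a doubled sequence (inference) ---
v = torch.arange(6, dtype=torch.float32).view(1, 1, 1, 6)  # [B=1, H=1, D=1, L=6]
delay = 2

rolled = torch.roll(v, -delay, -1)                         # circular shift by the delay

doubled = v.repeat(1, 1, 1, 2)                             # doubling means index + delay never overflows
init_index = torch.arange(6).view(1, 1, 1, 6)
gathered = torch.gather(doubled, dim=-1, index=init_index + delay)

print(rolled)                                              # tensor([[[[2., 3., 4., 5., 0., 1.]]]])
print(torch.equal(rolled, gathered))                       # True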


class AutoCorrelationLayer(nn.Module):
    def __init__(self, correlation, d_model, n_heads, d_keys=None,
                 d_values=None):
        super(AutoCorrelationLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_correlation = correlation
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        # Batch size and query sequence length
        # queries [B, L, d_model] -> B (batch size), L (query length)
        B, L, _ = queries.shape

        # Key sequence length
        # keys [B, S, d_model] -> S (key length)
        _, S, _ = keys.shape

        # Number of attention heads
        H = self.n_heads

        # Project the queries into the multi-head space and reshape
        # [B, L, d_model] -> linear -> [B, L, H*d_k] -> reshape -> [B, L, H, d_k]
        queries = self.query_projection(queries).view(B, L, H, -1)

        # Project the keys into the multi-head space and reshape
        # [B, S, d_model] -> linear -> [B, S, H*d_k] -> reshape -> [B, S, H, d_k]
        keys = self.key_projection(keys).view(B, S, H, -1)

        # Project the values into the multi-head space and reshape
        # [B, S, d_model] -> linear -> [B, S, H*d_v] -> reshape -> [B, S, H, d_v]
        values = self.value_projection(values).view(B, S, H, -1)

        # Compute the auto-correlation "attention" and aggregate period-based dependencies
        # [B, L, H, d_k], [B, S, H, d_k], [B, S, H, d_v] -> out [B, L, H, d_v], attn [B, L, H, d_k] or None
        out, attn = self.inner_correlation(
            queries,
            keys,
            values,
            attn_mask
        )

        # Merge the heads back into a single feature dimension
        # [B, L, H, d_v] -> reshape -> [B, L, H*d_v]
        out = out.view(B, L, -1)

        # Project the merged output back to the model dimension and return it with the attention map
        # out [B, L, H*d_v] -> linear -> [B, L, d_model], attn stays [B, L, H, E] or None
        return self.out_projection(out), attn
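
A quick shape check of the attention stack in isolation, with hypothetical sizes (B=4, L=96, d_model=512, 8 heads); the import path again assumes the official repository layout:

import torch
from layers.AutoCorrelation import AutoCorrelation, AutoCorrelationLayer  # assumed path

d_model, n_heads = 512, 8
attn = AutoCorrelationLayer(
    AutoCorrelation(False, factor=1, attention_dropout=0.1, output_attention=True),
    d_model, n_heads)

x = torch.randn(4, 96, d_model)                  # [B, L, d_model]
out, corr = attn(x, x, x, attn_mask=None)        # self-attention style call: queries = keys = values
print(out.shape)                                 # torch.Size([4, 96, 512])
print(corr.shape)                                # torch.Size([4, 96, 8, 64]), i.e. [B, L, H, E]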

Encoder.forward(x)
  --> EncoderLayer.forward(x)
      --> AutoCorrelationLayer.forward(x,x,x)
          --> AutoCorrelation.forward(queries,keys,values)
              --> AutoCorrelation.time_delay_agg_training/inference()

Functions and data flow of each layer

  1. Encoder

    • Top-level container that manages the stack of EncoderLayers
    • Calls each EncoderLayer in order and collects the attention maps
    • Input shape [B, L, d_model] → output shape [B, L, d_model]
  2. EncoderLayer

    • Implements the full "attention + residual + decomposition + feed-forward" pipeline
    • First applies the self-attention mechanism to the input
    • Then applies the residual connection, series decomposition and feed-forward network
    • Input shape [B, L, d_model] → output shape [B, L, d_model]
  3. AutoCorrelationLayer

    • Projects the input into the multi-head attention space
    • Calls the inner AutoCorrelation module to perform the auto-correlation computation
    • Merges the multi-head results and projects back to the model dimension
    • Input shape [B, L, d_model] → projection → [B, L, H, d_k/v] → output shape [B, L, d_model]
  4. AutoCorrelation

    • The core innovation of Autoformer: FFT-based auto-correlation "attention"
    • Discovers period-based dependencies via frequency-domain computation
    • Aggregates the value vectors according to the correlation scores
    • Input shape [B, L, H, d_k/v] → FFT → time-delay aggregation → output shape [B, L, H, d_v]

This design lets the model capture periodic patterns and long-range dependencies in time series effectively, without relying on the conventional dot-product attention mechanism.

Encoder initialization

        self.encoder = Encoder(
            [
                EncoderLayer(
                    AutoCorrelationLayer(
                        AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout,
                                        output_attention=configs.output_attention),
                        configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    moving_avg=configs.moving_avg,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for l in range(configs.e_layers)
            ],
            norm_layer=my_Layernorm(configs.d_model)
        )

class EncoderLayer(nn.Module):
    """
    Autoformer encoder layer with the progressive decomposition architecture
    """
    def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

class AutoCorrelationLayer(nn.Module):
    def __init__(self, correlation, d_model, n_heads, d_keys=None,
                 d_values=None):
        super(AutoCorrelationLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_correlation = correlation
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

class AutoCorrelation(nn.Module):
    """
    AutoCorrelation Mechanism with the following two phases:
    (1) period-based dependencies discovery
    (2) time delay aggregation
    This block can replace the self-attention family mechanism seamlessly.
    """
    def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False):
        super(AutoCorrelation, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

Autoformer Encoder 

Input: x [B, L, d_model]
      |
      v
+----------------+
|    Encoder     |  stack of several EncoderLayers
+----------------+
      |
      v
+----------------+       +--------------------------------------+
| EncoderLayer   | ----> | 1. self-attention + residual         |
+----------------+       | 2. series decomposition (seasonal)   |
      |                  | 3. feed-forward network (1x1 convs)  |
      |                  | 4. residual + series decomposition   |
      v                  +--------------------------------------+
+----------------+
| AutoCorrelation|  FFT-based auto-correlation
+----------------+
      |
      v
Output: x [B, L, d_model], attns

EncoderLayer internal processing flow
 

              Input x [B, L, d_model]
                    |
                    v
┌───────────────────────────────────┐
│   AutoCorrelation attention       │
└───────────────────────────────────┘
                    |
                    v
        x + dropout(new_x) (residual connection)
                    |
                    v
┌───────────────────────────────────┐
│   series_decomp1 (decomposition)  │
│   keep seasonal x, drop trend _   │
└───────────────────────────────────┘
                    |
                    v
┌───────────────────────────────────┐
│   feed-forward network (1D convs) │
│   1. expand: d_model -> d_ff      │
│   2. activation + dropout         │
│   3. reduce: d_ff -> d_model      │
└───────────────────────────────────┘
                    |
                    v
        x + y (second residual connection)
                    |
                    v
┌───────────────────────────────────┐
│   series_decomp2 (decomposition)  │
│   keep seasonal res, drop trend _ │
└───────────────────────────────────┘
                    |
                    v
          Output: res, attn

Key components

  1. AutoCorrelation: the core innovation; computes auto-correlation with the FFT to capture periodic patterns

    • Lower computational complexity than standard dot-product attention, better suited to long sequences
  2. series_decomp: splits the time series into a seasonal part and a trend part

    • The trend is extracted with a moving average (moving_avg)
    • The seasonal part is the original series minus the trend
    • The Autoformer encoder mainly keeps and processes the seasonal part
  3. Feed-forward network: two 1D convolution layers, analogous to the Transformer feed-forward block (see the sketch right after this list)

    • First expands the feature dimension: d_model → d_ff
    • Then compresses back to the original dimension: d_ff → d_model
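
The "convolutions" here are kernel-size-1 Conv1d layers applied to the transposed tensor, which is numerically identical to applying a Linear layer to every time step independently. A small standalone check of that equivalence (sizes are arbitrary):

import torch
import torch.nn as nn

d_model, d_ff, B, L = 8, 32, 2, 10
conv = nn.Conv1d(d_model, d_ff, kernel_size=1, bias=False)
lin = nn.Linear(d_model, d_ff, bias=False)
lin.weight.data = conv.weight.data.squeeze(-1)      # share the same weights

x = torch.randn(B, L, d_model)
y_conv = conv(x.transpose(-1, 1)).transpose(-1, 1)  # [B, L, d_ff], exactly as in EncoderLayer
y_lin = lin(x)                                      # per-timestep Linear
print(torch.allclose(y_conv, y_lin, atol=1e-6))     # True
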
AutoCorrelationLayer.forward(queries, keys, values, attn_mask)
    ├── project into the multi-head space
    │   ├── queries -> [B, L, H, d_k]
    │   ├── keys -> [B, S, H, d_k]
    │   └── values -> [B, S, H, d_v]
    ├── call AutoCorrelation.forward(queries, keys, values, attn_mask)
    │   ├── compute the auto-correlation in the frequency domain
    │   │   ├── queries -> FFT -> q_fft
    │   │   ├── keys -> FFT -> k_fft
    │   │   └── q_fft * conj(k_fft) -> res -> inverse FFT -> corr
    │   └── time-delay aggregation
    │       ├── training mode:  time_delay_agg_training(values, corr)
    │       └── inference mode: time_delay_agg_inference(values, corr)
    └── merge the multi-head output and project back to the model dimension
        └── out -> [B, L, d_model]
