Informer源码分析

1 数据准备

首先是数据准备阶段的入口函数,位于Exp_Informer类的train函数内

train_data, train_loader = self._get_data(flag = 'train')

self._get_data的实现如下,该函数主要就是根据所选择的数据集加载数据,之后构建DataSet和DataLoader:

def _get_data(self, flag):
    args = self.args

    data_dict = {
        'ETTh1':Dataset_ETT_hour,
        'ETTh2':Dataset_ETT_hour,
        'ETTm1':Dataset_ETT_minute,
        'ETTm2':Dataset_ETT_minute,
        'WTH':Dataset_Custom,
        'ECL':Dataset_Custom,
        'Solar':Dataset_Custom,
        'custom':Dataset_Custom,
    }
    Data = data_dict[self.args.data]
    # 时间特征编码格式
    timeenc = 0 if args.embed!='timeF' else 1

    if flag == 'test':
        shuffle_flag = False; drop_last = True; batch_size = args.batch_size; freq=args.freq
    elif flag=='pred':
        shuffle_flag = False; drop_last = False; batch_size = 1; freq=args.detail_freq
        Data = Dataset_Pred
    else:
        shuffle_flag = True; drop_last = True; batch_size = args.batch_size; freq=args.freq
    data_set = Data(
        root_path=args.root_path,
        data_path=args.data_path,
        flag=flag,
        size=[args.seq_len, args.label_len, args.pred_len],
        features=args.features,
        target=args.target,
        inverse=args.inverse,
        timeenc=timeenc,
        freq=freq,
        cols=args.cols
    )
    print(flag, len(data_set))
    data_loader = DataLoader(
        data_set,
        batch_size=batch_size,
        shuffle=shuffle_flag,
        num_workers=args.num_workers,
        drop_last=drop_last)

    return data_set, data_loader

数据集的加载可以按照不同的时间粒度进行构建,这里以Dataset_ETT_hour类为例子,其__init__函数如下

def __init__(self, root_path, flag='train', size=None, 
             features='S', data_path='ETTh1.csv', 
             target='OT', scale=True, inverse=False, timeenc=0, freq='h', cols=None):
    # size [seq_len, label_len, pred_len]
    # info
    if size == None:
        self.seq_len = 24*4*4
        self.label_len = 24*4
        self.pred_len = 24*4
    else:
        self.seq_len = size[0] #输入到informer编码器中序列的长度
        self.label_len = size[1] #输入到解码器中start token的序列长度
        self.pred_len = size[2] #需要预测的序列的长度
    # init
    assert flag in ['train', 'test', 'val']
    type_map = {'train':0, 'val':1, 'test':2}
    self.set_type = type_map[flag]
    
    self.features = features
    self.target = target # 要预测的目标维度特征
    self.scale = scale
    self.inverse = inverse
    self.timeenc = timeenc
    self.freq = freq #数据粒度
    
    self.root_path = root_path #数据集所在路径
    self.data_path = data_path #数据集文件名称
    self.__read_data__()

在初始化时最重要的函数就是_read_data_

def __read_data__(self):
    self.scaler = StandardScaler()
    # 加载csv数据为DataFrame
    df_raw = pd.read_csv(os.path.join(self.root_path,
                                      self.data_path))

    border1s = [0, 12*30*24 - self.seq_len, 12*30*24+4*30*24 - self.seq_len]
    border2s = [12*30*24, 12*30*24+4*30*24, 12*30*24+8*30*24]
    # 根据flag是train,val,test来选择加载的数据的起始于结束的位置
    # {'train':0, 'val':1, 'test':2}
    border1 = border1s[self.set_type]
    border2 = border2s[self.set_type]
    
    # 根据预测的类型进一步提取数据
    # M : multivariate predict multivariate, S : univariate predict univariate, MS : multivariate predict univariate
    if self.features=='M' or self.features=='MS':
        cols_data = df_raw.columns[1:] #df_raw.columns[0]是date
        df_data = df_raw[cols_data]
    elif self.features=='S':
        df_data = df_raw[[self.target]]
	# 如果需要进行归一化,则通过StandardScaler进行归一化处理
    if self.scale:
        train_data = df_data[border1s[0]:border2s[0]]
        self.scaler.fit(train_data.values)
        data = self.scaler.transform(df_data.values)
    else:
        data = df_data.values #将数据转化为numpy数组
        
    df_stamp = df_raw[['date']][border1:border2] #取出指定范围内的序列数据的'date'列
    df_stamp['date'] = pd.to_datetime(df_stamp.date)
    # 通过time_features对date数据进行进一步的处理
    data_stamp = time_features(df_stamp, timeenc=self.timeenc, freq=self.freq)
    
    self.data_x = data[border1:border2] #需要使用归一化后的data
    if self.inverse:
        self.data_y = df_data.values[border1:border2] #不需要归一化后的数据
    else:
        self.data_y = data[border1:border2] #如果self.inverse是False, data_x和data_y就完全相同
    self.data_stamp = data_stamp

1.1 时间的处理

在上面read_data中需要详细了解的是time_features函数,该函数的实现如下:

def time_features(dates, timeenc=1, freq='h'):
    """
    > `time_features` takes in a `dates` dataframe with a 'dates' column and extracts the date down to `freq` where freq can be any of the following if `timeenc` is 0: 
    > * m - [month]
    > * w - [month]
    > * d - [month, day, weekday]
    > * b - [month, day, weekday]
    > * h - [month, day, weekday, hour]
    > * t - [month, day, weekday, hour, *minute]
    > 
    > If `timeenc` is 1, a similar, but different list of `freq` values are supported (all encoded between [-0.5 and 0.5]): 
    > * Q - [month]
    > * M - [month]
    > * W - [Day of month, week of year]
    > * D - [Day of week, day of month, day of year]
    > * B - [Day of week, day of month, day of year]
    > * H - [Hour of day, day of week, day of month, day of year]
    > * T - [Minute of hour*, hour of day, day of week, day of month, day of year]
    > * S - [Second of minute, minute of hour, hour of day, day of week, day of month, day of year]

    *minute returns a number from 0-3 corresponding to the 15 minute period it falls into.
    """
    # timeenc = 0 if args.embed!='timeF' else 1
    if timeenc==0:
        dates['month'] = dates.date.apply(lambda row:row.month,1)
        dates['day'] = dates.date.apply(lambda row:row.day,1)
        dates['weekday'] = dates.date.apply(lambda row:row.weekday(),1)
        dates['hour'] = dates.date.apply(lambda row:row.hour,1)
        dates['minute'] = dates.date.apply(lambda row:row.minute,1)
        dates['minute'] = dates.minute.map(lambda x:x//15) #将分钟划分为4个离散值,每15分钟一个值
        freq_map = {
            'y':[],'m':['month'],'w':['month'],'d':['month','day','weekday'],
            'b':['month','day','weekday'],'h':['month','day','weekday','hour'],
            't':['month','day','weekday','hour','minute'],
        }
        return dates[freq_map[freq.lower()]].values
    if timeenc==1:
        dates = pd.to_datetime(dates.date.values)
        # time_features_from_frequency_str返回的是对应freq下的时间归一化对象列表,处理后会将时间按照需要的格式归一化到-0.5到0.5之间
        return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)]).transpose(1,0)

time_features_from_frequency_str的实现如下:

def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
    """
    Returns a list of time features that will be appropriate for the given frequency string.
    Parameters
    ----------
    freq_str
        Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.
    """

    features_by_offsets = {
        offsets.YearEnd: [],
        offsets.QuarterEnd: [MonthOfYear],
        offsets.MonthEnd: [MonthOfYear],
        offsets.Week: [DayOfMonth, WeekOfYear],
        offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Minute: [
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
        offsets.Second: [
            SecondOfMinute,
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
    }
	# pandas官网解释:https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.tseries.frequencies.to_offset.html
    offset = to_offset(freq_str)

    for offset_type, feature_classes in features_by_offsets.items():
        if isinstance(offset, offset_type):
            return [cls() for cls in feature_classes] #创建每个时间归一化对象

    supported_freq_msg = f"""
    Unsupported frequency {freq_str}
    The following frequencies are supported:
        Y   - yearly
            alias: A
        M   - monthly
        W   - weekly
        D   - daily
        B   - business days
        H   - hourly
        T   - minutely
            alias: min
        S   - secondly
    """
    raise RuntimeError(supported_freq_msg)

接下来回过来继续分析Dataset_ETT_hour类的函数_getitem_

def __getitem__(self, index):
    s_begin = index #起始下标值
    s_end = s_begin + self.seq_len #输入编码器的序列的结束下标值
    r_begin = s_end - self.label_len #start token起始下标值
    r_end = r_begin + self.label_len + self.pred_len #解码器输出的序列结束的下标值

    seq_x = self.data_x[s_begin:s_end]
    if self.inverse:
        seq_y = np.concatenate([self.data_x[r_begin:r_begin+self.label_len], self.data_y[r_begin+self.label_len:r_end]], 0)
    else:
        seq_y = self.data_y[r_begin:r_end]
    seq_x_mark = self.data_stamp[s_begin:s_end]
    seq_y_mark = self.data_stamp[r_begin:r_end]

    return seq_x, seq_y, seq_x_mark, seq_y_mark

2 训练阶段

def train(self, setting):
    train_data, train_loader = self._get_data(flag = 'train')
    vali_data, vali_loader = self._get_data(flag = 'val')
    test_data, test_loader = self._get_data(flag = 'test')
	
    # 生成用于存储checkpoint的文件路径
    path = os.path.join(self.args.checkpoints, setting)
    if not os.path.exists(path):
        os.makedirs(path)

    time_now = time.time()
    
    train_steps = len(train_loader)
    early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)
    
    model_optim = self._select_optimizer() #Adam
    criterion =  self._select_criterion() #MSELoss
	
    if self.args.use_amp:
        scaler = torch.cuda.amp.GradScaler() #automatic-mixed-precision 模型进行轻量化

    for epoch in range(self.args.train_epochs):
        iter_count = 0
        train_loss = []
        
        self.model.train() #声明模型当前处于训练状态,BatchNorm或Dropout都需要开启
        epoch_time = time.time()
        for i, (batch_x,batch_y,batch_x_mark,batch_y_mark) in enumerate(train_loader):
            iter_count += 1
            
            model_optim.zero_grad()
            pred, true = self._process_one_batch(
                train_data, batch_x, batch_y, batch_x_mark, batch_y_mark)
            loss = criterion(pred, true)
            train_loss.append(loss.item())
            
            if (i+1) % 100==0:
                print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                speed = (time.time()-time_now)/iter_count
                left_time = speed*((self.args.train_epochs - epoch)*train_steps - i)
                print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                iter_count = 0
                time_now = time.time()
            
            if self.args.use_amp:
                scaler.scale(loss).backward()
                scaler.step(model_optim)
                scaler.update()
            else:
                loss.backward()
                model_optim.step()

        print("Epoch: {} cost time: {}".format(epoch+1, time.time()-epoch_time))
        train_loss = np.average(train_loss)
        vali_loss = self.vali(vali_data, vali_loader, criterion)
        test_loss = self.vali(test_data, test_loader, criterion)

        print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
            epoch + 1, train_steps, train_loss, vali_loss, test_loss))
        early_stopping(vali_loss, self.model, path)
        if early_stopping.early_stop:
            print("Early stopping")
            break

        adjust_learning_rate(model_optim, epoch+1, self.args)
        
    best_model_path = path+'/'+'checkpoint.pth'
    self.model.load_state_dict(torch.load(best_model_path))
    
    return self.model

这里核心是函数_process_one_batch,其实现如下:

def _process_one_batch(self, dataset_object, batch_x, batch_y, batch_x_mark, batch_y_mark):
    batch_x = batch_x.float().to(self.device)
    batch_y = batch_y.float()

    batch_x_mark = batch_x_mark.float().to(self.device)
    batch_y_mark = batch_y_mark.float().to(self.device)

    # decoder input (以下的输入序列的构造应该是在填充需要预测的mask序列的值为固定内容0或1)
    if self.args.padding==0:
        # 这里的维度[batch_size, pred_len, feature_dim]
        dec_inp = torch.zeros([batch_y.shape[0], self.args.pred_len, batch_y.shape[-1]]).float()
    elif self.args.padding==1:
        dec_inp = torch.ones([batch_y.shape[0], self.args.pred_len, batch_y.shape[-1]]).float() #需要预测的序列端用全1代替
    dec_inp = torch.cat([batch_y[:,:self.args.label_len,:], dec_inp], dim=1).float().to(self.device) #将start token部分与需要预测的序列部分合并
    # encoder - decoder
    if self.args.use_amp:
        with torch.cuda.amp.autocast():
            if self.args.output_attention:
                outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
            else:
                outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
    else:
        if self.args.output_attention:
            outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
        else:
            outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
    if self.args.inverse:
        outputs = dataset_object.inverse_transform(outputs) #将输出的数据反归一化
    f_dim = -1 if self.args.features=='MS' else 0 #如果预测是‘MS’的单输出就是在最后一个位置,否则就是全部所有的维度的特征
    batch_y = batch_y[:,-self.args.pred_len:,f_dim:].to(self.device)

    return outputs, batch_y

在了解具体的训练逻辑之前,我们需要看一下模型是在什么时候初始化的。模型的初始化是在Exp_Informer类初始话的时候完成的,其中调用了函数_build_model。Exp_Informer继承自Exp_Basic类,Exp_Basic类的定义如下:

class Exp_Basic(object):
    def __init__(self, args):
        self.args = args
        self.device = self._acquire_device()
        self.model = self._build_model().to(self.device) #调用_build_model完成模型架构的搭建并放入指定设备中运算

    def _build_model(self):
        raise NotImplementedError
        return None
    
    def _acquire_device(self):
        if self.args.use_gpu:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(self.args.gpu) if not self.args.use_multi_gpu else self.args.devices
            device = torch.device('cuda:{}'.format(self.args.gpu))
            print('Use GPU: cuda:{}'.format(self.args.gpu))
        else:
            device = torch.device('cpu')
            print('Use CPU')
        return device

    def _get_data(self):
        pass

    def vali(self):
        pass

    def train(self):
        pass

    def test(self):
        pass

下面是在Exp_Informer中实现的_build_model函数的详细代码

def _build_model(self):
    model_dict = {
        'informer':Informer,
        'informerstack':InformerStack,
    }
    if self.args.model=='informer' or self.args.model=='informerstack':
        e_layers = self.args.e_layers if self.args.model=='informer' else self.args.s_layers
        model = model_dict[self.args.model](
            self.args.enc_in, #编码器输入的大小
            self.args.dec_in, #解码器输入的大小
            self.args.c_out,
            self.args.seq_len,
            self.args.label_len,
            self.args.pred_len, 
            self.args.factor, #Probsparse attn factor 应该是采样系数
            self.args.d_model, #模型的维度(可能是embeding层和self-attention输出的维度)
            self.args.n_heads, #注意力头的数量
            e_layers, # self.args.e_layers,编码器layer的数量
            self.args.d_layers, #解码器layer的数量
            self.args.d_ff, #全连接层维度的大小
            self.args.dropout, #dropout的概率
            self.args.attn, #self-attention是使用informer的还是transformer的
            self.args.embed, #时间特征编码的方式
            self.args.freq, #时间编码时的粒度(s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly).You can also use more detailed freq like 15min or 3h
            self.args.activation, #激活函数的选择
            self.args.output_attention, #表示在inner_attention即FullAttention和ProbAttention中的forward()中是否要将attention的softmax结果返回
            self.args.distil, #是否进行蒸馏操作
            self.args.mix,
            self.device
        ).float()
    
    if self.args.use_multi_gpu and self.args.use_gpu:
        model = nn.DataParallel(model, device_ids=self.args.device_ids)
    return model

3 模型架构

3.1 Informer架构

以下时Informer的初始化过程

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-x11hnMVn-1654163034088)(C:\work\Note\img\encoderStructure.png)]

def __init__(self, enc_in, dec_in, c_out, seq_len, label_len, out_len, 
            factor=5, d_model=512, n_heads=8, e_layers=3, d_layers=2, d_ff=512, 
            dropout=0.0, attn='prob', embed='fixed', freq='h', activation='gelu', 
            output_attention = False, distil=True, mix=True,
            device=torch.device('cuda:0')):
    super(Informer, self).__init__()
    self.pred_len = out_len
    self.attn = attn
    self.output_attention = output_attention

    # Encoding enc_in=7, dec_in=7, d_model=512 这里输入的大小为7应该是因为输入的数据每一个时刻由7维特征组成
    self.enc_embedding = DataEmbedding(enc_in, d_model, embed, freq, dropout)
    self.dec_embedding = DataEmbedding(dec_in, d_model, embed, freq, dropout)
    # Attention
    Attn = ProbAttention if attn=='prob' else FullAttention
    # Encoder
    self.encoder = Encoder(
        [
            EncoderLayer(
                AttentionLayer(Attn(False, factor, attention_dropout=dropout, output_attention=output_attention), 
                            d_model, n_heads, mix=False),
                d_model,
                d_ff,
                dropout=dropout,
                activation=activation
            ) for l in range(e_layers)
        ],
        [
            ConvLayer(
                d_model
            ) for l in range(e_layers-1)
        ] if distil else None,
        norm_layer=torch.nn.LayerNorm(d_model)
    )
    # Decoder
    self.decoder = Decoder(
        [
            DecoderLayer(
                AttentionLayer(Attn(True, factor, attention_dropout=dropout, output_attention=False), 
                            d_model, n_heads, mix=mix),
                AttentionLayer(FullAttention(False, factor, attention_dropout=dropout, output_attention=False), 
                            d_model, n_heads, mix=False),
                d_model,
                d_ff,
                dropout=dropout,
                activation=activation,
            )
            for l in range(d_layers)
        ],
        norm_layer=torch.nn.LayerNorm(d_model)
    )
    # self.end_conv1 = nn.Conv1d(in_channels=label_len+out_len, out_channels=out_len, kernel_size=1, bias=True)
    # self.end_conv2 = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=1, bias=True)
    self.projection = nn.Linear(d_model, c_out, bias=True)
def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, 
            enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
    enc_out = self.enc_embedding(x_enc, x_mark_enc)
    enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)

    dec_out = self.dec_embedding(x_dec, x_mark_dec)
    dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
    dec_out = self.projection(dec_out)
    
    # dec_out = self.end_conv1(dec_out)
    # dec_out = self.end_conv2(dec_out.transpose(2,1)).transpose(1,2)
    if self.output_attention:
        return dec_out[:,-self.pred_len:,:], attns
    else:
        return dec_out[:,-self.pred_len:,:] # [B, L, D]

3.1.1 DataEmbedding

class DataEmbedding(nn.Module):
    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        #标准化后的时间才会使用TimeFeatureEmbedding,这是一个可学习的时间编码
        self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) if embed_type!='timeF' else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) 

        self.dropout = nn.Dropout(p=dropout)
	
    #这里x的输入维度应该是[batch_size, seq_len, dim_feature],x_mark的维度应该是[batch_size, seq_len, dim_date]
    def forward(self, x, x_mark):
        # 这里将三个embedding的结果相加,具体原因可以参考
        x = self.value_embedding(x) + self.position_embedding(x) + self.temporal_embedding(x_mark)
        
        return self.dropout(x)

从DataEmbedding的结构可以看出其中分别构建了tokenEmbedding、positionEmbedding、temporalEmbedding三个模块

3.1.1.1 tokenEmbedding
class TokenEmbedding(nn.Module):
    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        padding = 1 if torch.__version__>='1.5.0' else 2
        # nn.Conv1d对输入序列的每一个时刻的特征进行一维卷积,且这里stride使用默认的1
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model, 
                                    kernel_size=3, padding=padding, padding_mode='circular')
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight,mode='fan_in',nonlinearity='leaky_relu')

    def forward(self, x):
        # https://pytorch.org/docs/master/generated/torch.nn.Conv1d.html#torch.nn.Conv1d
        # 因为Conv1d要求输入是(N, Cin, L)输出是(N, Cout, L),所以需要对输入样本维度顺序进行调整
        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1,2)  
        return x
3.1.1.2 positionEmbedding
class PositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float() #创建出了5000个位置的编码,但可能并不需要5000个长度的编码
        pe.require_grad = False

        position = torch.arange(0, max_len).float().unsqueeze(1) # 生成维度为[5000, 1]的位置下标向量
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        return self.pe[:, :x.size(1)]
3.1.1.3 TemporalEmbedding
class TemporalEmbedding(nn.Module):
    def __init__(self, d_model, embed_type='fixed', freq='h'):
        super(TemporalEmbedding, self).__init__()

        minute_size = 4; hour_size = 24
        weekday_size = 7; day_size = 32; month_size = 13

        Embed = FixedEmbedding if embed_type=='fixed' else nn.Embedding
        if freq=='t':
            self.minute_embed = Embed(minute_size, d_model)
        self.hour_embed = Embed(hour_size, d_model)
        self.weekday_embed = Embed(weekday_size, d_model)
        self.day_embed = Embed(day_size, d_model)
        self.month_embed = Embed(month_size, d_model)
        
   def forward(self, x):
        x = x.long()
        # 在数据准备阶段,对于时间的处理时若freq=‘h’时'h':['month','day','weekday','hour']
        minute_x = self.minute_embed(x[:,:,4]) if hasattr(self, 'minute_embed') else 0.
        hour_x = self.hour_embed(x[:,:,3])
        weekday_x = self.weekday_embed(x[:,:,2])
        day_x = self.day_embed(x[:,:,1])
        month_x = self.month_embed(x[:,:,0])
        
        return hour_x + weekday_x + day_x + month_x + minute_x
class FixedEmbedding(nn.Module):
    def __init__(self, c_in, d_model):# c_in表示有多少个位置,在时间编码中表示每一维时间特征的粒度(h:24, m:4, weekday:7, day:32, month:13)
        super(FixedEmbedding, self).__init__()

        w = torch.zeros(c_in, d_model).float()
        w.require_grad = False

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        w[:, 0::2] = torch.sin(position * div_term)
        w[:, 1::2] = torch.cos(position * div_term)

        self.emb = nn.Embedding(c_in, d_model)
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        return self.emb(x).detach() #不进行训练
3.1.1.4 TimeFeatureEmbedding
class TimeFeatureEmbedding(nn.Module):
    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super(TimeFeatureEmbedding, self).__init__()

        freq_map = {'h':4, 't':5, 's':6, 'm':1, 'a':1, 'w':2, 'd':3, 'b':3}
        d_inp = freq_map[freq]
        self.embed = nn.Linear(d_inp, d_model)
    
    def forward(self, x):
        return self.embed(x)

3.1.2 self-attention

在分析Informer的ProbAttention之前我们先来分析一下原始self-Attention的源码,如下代码是Attention机制的整体架构,其中d_keys和d_values的维度与d_model和n_heads(注意力头)相关,AttentionLayer中的forward过程主要做的事情就是将embedding的输入序列映射到n_heads个注意力头,并通过具体的inner_attention来完成自注意力的过程:

class AttentionLayer(nn.Module):
    def __init__(self, attention, d_model, n_heads, 
                 d_keys=None, d_values=None, mix=False):
        super(AttentionLayer, self).__init__()

        d_keys = d_keys or (d_model//n_heads) #如果d_model=512并且采用默认n_heads=8时,d_keys=64
        d_values = d_values or (d_model//n_heads)

        self.inner_attention = attention # FullAttention or ProbAttention
        # https://pytorch.org/docs/master/generated/torch.nn.Linear.html?highlight=nn%20linear#torch.nn.Linear
        # 由官方对nn.Linear的介绍可知,全连接只针对最后一维特征进行全连接
        self.query_projection = nn.Linear(d_model, d_keys * n_heads) 
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads
        self.mix = mix #Informer类和InformerStack类中是False
        
    def forward(self, queries, keys, values, attn_mask):
        B, L, _ = queries.shape 
        _, S, _ = keys.shape #这里S与L的维度应该相同才对
        H = self.n_heads

        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_attention(
            queries,
            keys,
            values,
            attn_mask
        )
        if self.mix:
            out = out.transpose(2,1).contiguous()
        out = out.view(B, L, -1) #out的维度应该是[batch_size, seq_len, d_values*n_heads]

        return self.out_projection(out), attn #out_projection前向过程结束后张量维度应该是[batch_size, seq_len, d_model]
3.1.2.1 transformer attention

下面分析一下FullAttention,即原始的Transformer中的self-attention的具体源码

class FullAttention(nn.Module):
    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
        
    def forward(self, queries, keys, values, attn_mask):
        B, L, H, E = queries.shape # batch_size, seq_len, head_num, dim_feature 
        _, S, _, D = values.shape
        scale = self.scale or 1./sqrt(E) #默认是没有1/sqrt(d)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)#score的维度应该是[batch_size, head_num, seq_len, seq_len]
        if self.mask_flag: #默认mask_flag是True,attn_mask是None
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device) #获取mask的上三角矩阵(去除了对角线上的值)得到的mask矩阵的维度为[batch_size, 1, seq_len, seq_len]

            scores.masked_fill_(attn_mask.mask, -np.inf) #将attn_mask中的mask矩阵所有值为1的位置都设置成-np.inf,且具有广播机制

        A = self.dropout(torch.softmax(scale * scores, dim=-1)) # 为什么要进行dropout?
        V = torch.einsum("bhls,bshd->blhd", A, values)
        
        if self.output_attention:
            return (V.contiguous(), A)
        else:
            return (V.contiguous(), None)

其中TriangularCausalMask类的主要功能是给每一个batch的序列打上mask,因为输入某一时刻的特征是无法知道未来时刻的特征的

class TriangularCausalMask():
    def __init__(self, B, L, device="cpu"):
        mask_shape = [B, 1, L, L]
        with torch.no_grad():
            self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)

    @property
    def mask(self):
        return self._mask
3.1.2.2 Informer attention
class ProbAttention(nn.Module):
    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(ProbAttention, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

首先我们来看一下ProbAttention的forward函数:

def forward(self, queries, keys, values, attn_mask):
    B, L_Q, H, D = queries.shape
    _, L_K, _, _ = keys.shape #如论文所述,这里的L_Q=L_K=L

    queries = queries.transpose(2,1) #将queries转换为B, H, L, D
    keys = keys.transpose(2,1)
    values = values.transpose(2,1)
	#这里最后用.item()将结果转为纯python int类型
    U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k)
    u = self.factor * np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q) 
	
    U_part = U_part if U_part<L_K else L_K
    u = u if u<L_Q else L_Q
    # scores_top:[B, H, n_top, L_K], index:[B, H, n_top]
    scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u) 

    # add scale factor
    scale = self.scale or 1./sqrt(D)
    if scale is not None:
        scores_top = scores_top * scale
    # get the context
    context = self._get_initial_context(values, L_Q)
    # update the context with selected top_k queries
    # self.output_attention是False时返回的attn也就是None, context:[B, H, L_Q, D]
    context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask) 
    # context.transpose(2,1).contiguous()的维度是[B, L_Q, H, D]
    return context.transpose(2,1).contiguous(), attn

上面的代码中涉及到几个重要的函数实现,我们接下来依次进行解析,首先是函数_prob_QK,该函数返回经过筛选后的query和key内积后的结果,以及筛选出的u-top个query的index:

def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
    # Q [B, H, L, D]
    B, H, L_K, E = K.shape
    _, _, L_Q, _ = Q.shape

    # calculate the sampled Q_K
    # 这里扩展出来一维L_Q,表示每一个query都有L_K个对应的key,且每个key是长度为E的向量 k_expand的维度是[B, H, L_Q, L_K, E]
    K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) 
    # index_sample的维度是[L_Q, sample_k]
    index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q 为每一个query采样sampe_k个key的index
    # K_sample的维度是[B, H, L_Q, sample_k, E]
    # torch.arange(L_Q).unsqueeze(1)这句话相当于在L_Q的这个维度之后加了一个维度
    K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
    # Q.unsqueeze(-2)后的维度是[B, H, L_Q, 1, D], K_sample.transpose(-2, -1)后的维度是[B, H, L_Q, E, sample_k],此处D与E的长度应该相同
    # torch.matmul的计算应该可以这样理解,由于D=E(源码中应该是64),所以将[1, D]的张量与[D, sample_k]的张量做矩阵相乘,然后删除1的维度则最终
    # Q_K_sample的维度就是[B, H, L_Q, sample_k], 含义为每一个query与采样下来的sample_k个key内积后的attention结果
    Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2) # 这里的squeeze需要加上-2如果不加的话,在运行batch_size=1时会出现问题

    # find the Top_k query with sparisty measurement
    # torch.div(Q_K_sample.sum(-1), L_K)计算后的维度是[B, H, L_Q], Q_K_sample.max(-1)[0]计算后的维度是[B, H, L_Q]
    M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
    # M_top的维度是[B, H, n_top]
    M_top = M.topk(n_top, sorted=False)[1]

    # use the reduced Q to calculate Q_K
    # Q_reduce的维度是[B, H, n_top, D]
    Q_reduce = Q[torch.arange(B)[:, None, None],
                 torch.arange(H)[None, :, None],
                 M_top, :] # factor*ln(L_q)
    # Q_k的维度是[B, H, n_top, L_K]
    Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k

    return Q_K, M_top

我们这里来逐行的分析_prob_QK函数的过程,首先我们造一个假的输入数据,K的维度是(1, 2, 4, 6),Q与K相同:

K = torch.linspace(1, 48, steps=48).resize(1, 2, 4, 6)
Q = torch.linspace(1, 48, steps=48).resize(1, 2, 4, 6)
'''
Q与K相同:(1, 2, 4, 6)其中1-24的数据属于head-1,25-48的数据属于head-2
tensor([[[[ 1.,  2.,  3.,  4.,  5.,  6.],
          [ 7.,  8.,  9., 10., 11., 12.],
          [13., 14., 15., 16., 17., 18.],
          [19., 20., 21., 22., 23., 24.]],

         [[25., 26., 27., 28., 29., 30.],
          [31., 32., 33., 34., 35., 36.],
          [37., 38., 39., 40., 41., 42.],
          [43., 44., 45., 46., 47., 48.]]]])
'''
K_unsqueeze = K.unsqueeze(-3)
'''
K_unsqueeze:(1, 2, 1, 4, 6)
tensor([[[[[ 1.,  2.,  3.,  4.,  5.,  6.],
           [ 7.,  8.,  9., 10., 11., 12.],
           [13., 14., 15., 16., 17., 18.],
           [19., 20., 21., 22., 23., 24.]]],


         [[[25., 26., 27., 28., 29., 30.],
           [31., 32., 33., 34., 35., 36.],
           [37., 38., 39., 40., 41., 42.],
           [43., 44., 45., 46., 47., 48.]]]]])
'''
K_expand = K_unsqueeze.expand(B, H, L_Q, L_K, E)
'''
K_expand:(1, 2, 4, 4, 6)相当于在倒数第三维上将最后两维的数据复制了四遍
tensor([[[[[ 1.,  2.,  3.,  4.,  5.,  6.],
           [ 7.,  8.,  9., 10., 11., 12.],
           [13., 14., 15., 16., 17., 18.],
           [19., 20., 21., 22., 23., 24.]],

          [[ 1.,  2.,  3.,  4.,  5.,  6.],
           [ 7.,  8.,  9., 10., 11., 12.],
           [13., 14., 15., 16., 17., 18.],
           [19., 20., 21., 22., 23., 24.]],

          [[ 1.,  2.,  3.,  4.,  5.,  6.],
           [ 7.,  8.,  9., 10., 11., 12.],
           [13., 14., 15., 16., 17., 18.],
           [19., 20., 21., 22., 23., 24.]],

          [[ 1.,  2.,  3.,  4.,  5.,  6.],
           [ 7.,  8.,  9., 10., 11., 12.],
           [13., 14., 15., 16., 17., 18.],
           [19., 20., 21., 22., 23., 24.]]],


         [[[25., 26., 27., 28., 29., 30.],
           [31., 32., 33., 34., 35., 36.],
           [37., 38., 39., 40., 41., 42.],
           [43., 44., 45., 46., 47., 48.]],

          [[25., 26., 27., 28., 29., 30.],
           [31., 32., 33., 34., 35., 36.],
           [37., 38., 39., 40., 41., 42.],
           [43., 44., 45., 46., 47., 48.]],

          [[25., 26., 27., 28., 29., 30.],
           [31., 32., 33., 34., 35., 36.],
           [37., 38., 39., 40., 41., 42.],
           [43., 44., 45., 46., 47., 48.]],

          [[25., 26., 27., 28., 29., 30.],
           [31., 32., 33., 34., 35., 36.],
           [37., 38., 39., 40., 41., 42.],
           [43., 44., 45., 46., 47., 48.]]]]])
'''
index_sample = torch.randint(L_K, (L_Q, sample_k)) # (4, 2)
'''
index_sample:(4, 2)
tensor([[3, 3],
        [3, 0],
        [2, 3],
        [0, 3]])
'''
K_tmp_id = torch.arange(L_Q).unsqueeze(1) # (4, 1)
'''
K_tmp_id:(4, 1)
tensor([[0],
        [1],
        [2],
        [3]])
'''
K_sample = K_expand[:, :, K_tmp_id, index_sample, :]
'''
K_sample:(1, 2, 4, 2, 6)
tensor([[[[[19., 20., 21., 22., 23., 24.],
           [19., 20., 21., 22., 23., 24.]],head-1对应K_tmp_id[0], index_sample的[3, 3]

          [[19., 20., 21., 22., 23., 24.],
           [ 1.,  2.,  3.,  4.,  5.,  6.]],head-1对应K_tmp_id[1], index_sample的[3, 0]

          [[13., 14., 15., 16., 17., 18.],
           [19., 20., 21., 22., 23., 24.]],head-1对应K_tmp_id[2], index_sample的[2, 3]

          [[ 1.,  2.,  3.,  4.,  5.,  6.],
           [19., 20., 21., 22., 23., 24.]]],head-1对应K_tmp_id[3], index_sample的[0, 3]


         [[[43., 44., 45., 46., 47., 48.],
           [43., 44., 45., 46., 47., 48.]],head-2对应K_tmp_id[0], index_sample的[3, 3]

          [[43., 44., 45., 46., 47., 48.],
           [25., 26., 27., 28., 29., 30.]],head-2对应K_tmp_id[1], index_sample的[3, 0]

          [[37., 38., 39., 40., 41., 42.],
           [43., 44., 45., 46., 47., 48.]],head-2对应K_tmp_id[2], index_sample的[2, 3]

          [[25., 26., 27., 28., 29., 30.],
           [43., 44., 45., 46., 47., 48.]]]]])head-2对应K_tmp_id[3], index_sample的[0, 3]
'''
Q_unsqueeze = Q.unsqueeze(-2)
'''
Q_unsqueeze:(1, 2, 4, 1, 6)
tensor([[[[[ 1.,  2.,  3.,  4.,  5.,  6.]],

          [[ 7.,  8.,  9., 10., 11., 12.]],

          [[13., 14., 15., 16., 17., 18.]],

          [[19., 20., 21., 22., 23., 24.]]],


         [[[25., 26., 27., 28., 29., 30.]],

          [[31., 32., 33., 34., 35., 36.]],

          [[37., 38., 39., 40., 41., 42.]],

          [[43., 44., 45., 46., 47., 48.]]]]])
'''
K_sample_trans = K_sample.transpose(-2, -1)
'''
K_sample_trans:(1, 2, 4, 6, 2)
tensor([[[[[19., 19.],
           [20., 20.],
           [21., 21.],
           [22., 22.],
           [23., 23.],
           [24., 24.]],

          [[19.,  1.],
           [20.,  2.],
           [21.,  3.],
           [22.,  4.],
           [23.,  5.],
           [24.,  6.]],

          [[13., 19.],
           [14., 20.],
           [15., 21.],
           [16., 22.],
           [17., 23.],
           [18., 24.]],

          [[ 1., 19.],
           [ 2., 20.],
           [ 3., 21.],
           [ 4., 22.],
           [ 5., 23.],
           [ 6., 24.]]],


         [[[43., 43.],
           [44., 44.],
           [45., 45.],
           [46., 46.],
           [47., 47.],
           [48., 48.]],

          [[43., 25.],
           [44., 26.],
           [45., 27.],
           [46., 28.],
           [47., 29.],
           [48., 30.]],

          [[37., 43.],
           [38., 44.],
           [39., 45.],
           [40., 46.],
           [41., 47.],
           [42., 48.]],

          [[25., 43.],
           [26., 44.],
           [27., 45.],
           [28., 46.],
           [29., 47.],
           [30., 48.]]]]])
'''
Q_K_sample_nonsqueeze = torch.matmul(Q_unsqueeze, K_sample_trans)
'''
Q_unsqueeze:(1, 2, 4, 1, 6)
tensor([[[[[ 1.,  2.,  3.,  4.,  5.,  6.]],

          [[ 7.,  8.,  9., 10., 11., 12.]],

          [[13., 14., 15., 16., 17., 18.]],

          [[19., 20., 21., 22., 23., 24.]]],


         [[[25., 26., 27., 28., 29., 30.]],

          [[31., 32., 33., 34., 35., 36.]],

          [[37., 38., 39., 40., 41., 42.]],

          [[43., 44., 45., 46., 47., 48.]]]]])
 K_sample_trans:(1, 2, 4, 6, 2)      
 tensor([[[[[19., 19.],
           [20., 20.],
           [21., 21.],
           [22., 22.],
           [23., 23.],
           [24., 24.]],

          [[19.,  1.],
           [20.,  2.],
           [21.,  3.],
           [22.,  4.],
           [23.,  5.],
           [24.,  6.]],

          [[13., 19.],
           [14., 20.],
           [15., 21.],
           [16., 22.],
           [17., 23.],
           [18., 24.]],

          [[ 1., 19.],
           [ 2., 20.],
           [ 3., 21.],
           [ 4., 22.],
           [ 5., 23.],
           [ 6., 24.]]],


         [[[43., 43.],
           [44., 44.],
           [45., 45.],
           [46., 46.],
           [47., 47.],
           [48., 48.]],

          [[43., 25.],
           [44., 26.],
           [45., 27.],
           [46., 28.],
           [47., 29.],
           [48., 30.]],

          [[37., 43.],
           [38., 44.],
           [39., 45.],
           [40., 46.],
           [41., 47.],
           [42., 48.]],

          [[25., 43.],
           [26., 44.],
           [27., 45.],
           [28., 46.],
           [29., 47.],
           [30., 48.]]]]])
          
Q_K_sample_nonsqueeze:(1, 2, 4, 1, 2)  
tensor([[[[[  469.,   469.]],

          [[ 1243.,   217.]],

          [[ 1459.,  2017.]],

          [[  469.,  2791.]]],


         [[[ 7525.,  7525.]],

          [[ 9163.,  5545.]],

          [[ 9379., 10801.]],

          [[ 7525., 12439.]]]]])
'''
Q_K_sample = Q_K_sample_nonsqueeze.squeeze(-2)
'''
Q_K_sample:(1, 2, 4, 2)
tensor([[[[  469.,   469.],
          [ 1243.,   217.],
          [ 1459.,  2017.],
          [  469.,  2791.]],

         [[ 7525.,  7525.],
          [ 9163.,  5545.],
          [ 9379., 10801.],
          [ 7525., 12439.]]]])
'''
# find the Top_k query with sparisty measurement
Q_K_sample_sum = Q_K_sample.sum(-1) #这一步就是完成了所有query与抽样出来的key进行内积的过程
'''
Q_K_sample_sum:(1, 2, 4)
tensor([[[  938.,  1460.,  3476.,  3260.],
         [15050., 14708., 20180., 19964.]]])
'''
Q_K_sample_max = Q_K_sample.max(-1)
'''
torch.return_types.max(
values=tensor([[[  469.,  1243.,  2017.,  2791.],
         [ 7525.,  9163., 10801., 12439.]]]),
indices=tensor([[[0, 0, 1, 1],
         [0, 0, 1, 1]]]))'''
div_tmp = torch.div(Q_K_sample_sum, L_K)
'''
div_tmp:(1, 2, 4)
tensor([[[ 234.5000,  365.0000,  869.0000,  815.0000],
         [3762.5000, 3677.0000, 5045.0000, 4991.0000]]])
'''
# 这一步就是在计算论文中的公式(4)
M = Q_K_sample_max[0] - torch.div(Q_K_sample_sum, L_K)  # (32, 8, 12)
'''
M:(1, 2, 4)
tensor([[[ 234.5000,  878.0000, 1148.0000, 1976.0000],
         [3762.5000, 5486.0000, 5756.0000, 7448.0000]]])
'''
# 选择其中的top_u个序列
M_top_tmp = M.topk(n_top, sorted=False)
'''
torch.return_types.topk(
values=tensor([[[1976., 1148.],
         [7448., 5756.]]]),
indices=tensor([[[3, 2],
         [3, 2]]]))
'''
M_top = M_top_tmp[1]  # (32, 8, 12)
'''
M_top:(1, 2, n_top:2)
tensor([[[3, 2],
         [3, 2]]])
'''
# use the reduced Q to calculate Q_K
Q_reduce_B = torch.arange(B)[:, None, None]
'''
Q_reduce_B:(1, 1, 1)
tensor([[[0]]])
'''
Q_reduce_H = torch.arange(H)[None, :, None]
'''
Q_reduce_H:(1, 2, 1)
tensor([[[0],
         [1]]])
'''
#到这里就得到了论文中公式(3)的Q^
Q_reduce = Q[Q_reduce_B,
             Q_reduce_H,
             M_top, :]  # factor*ln(L_q)
'''
Q:(1, 2, 4, 6)
tensor([[[[ 1.,  2.,  3.,  4.,  5.,  6.],
          [ 7.,  8.,  9., 10., 11., 12.],
          [13., 14., 15., 16., 17., 18.],
          [19., 20., 21., 22., 23., 24.]],

         [[25., 26., 27., 28., 29., 30.],
          [31., 32., 33., 34., 35., 36.],
          [37., 38., 39., 40., 41., 42.],
          [43., 44., 45., 46., 47., 48.]]]])


Q_reduce:(1, head:2, n_top:2, dim:6) 这里拿到的就是Q的head-1中位于坐标[3, 2]和head-2中位于坐标[3, 2]位置的query
tensor([[[[19., 20., 21., 22., 23., 24.],
          [13., 14., 15., 16., 17., 18.]],

         [[43., 44., 45., 46., 47., 48.],
          [37., 38., 39., 40., 41., 42.]]]])
'''
Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))  # factor*ln(L_q)*L_k
'''
Q_K:(1, 2, n_top:2, 4)
tensor([[[[  343.,   901.,  1459.,  2017.],
          [  469.,  1243.,  2017.,  2791.]],

         [[ 6535.,  7957.,  9379., 10801.],
          [ 7525.,  9163., 10801., 12439.]]]])
'''
return Q_K, M_top

通过上面代码的逐行打印分析可以发现,其实最终计算返回的Q_K就是计算论文中公式(3)的 Q ‾ K T \overline{Q}K^T QKT
Λ ( Q , K , V ) = S o f t m a x ( Q ‾ K T d ) V \Lambda(Q,K,V)=Softmax(\frac{\overline{Q}K^T}{\sqrt{d}})V Λ(Q,K,V)=Softmax(d QKT)V
下面分析函数_get_initial_context,该函数的主要作用就是将V按照倒数第二维度进行均值计算,并扩展复制到多个head的维度:

def _get_initial_context(self, V, L_Q):
    B, H, L_V, D = V.shape
    if not self.mask_flag:
        # V_sum = V.sum(dim=-2)
        V_sum = V.mean(dim=-2) #V_sum的维度是[B, H, D]即得到所有维度特征的mean,这里之所以取mean应该也是和论文中解释没有任何信息提供的attention就和平均分布一样,所以在论文中将attention后的结果与平均分布进行KL散度的距离计算
        contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone() #contex的维度是[B, H, L_Q, D]即将每一个V_sum的最后一个特征维度复制了L_Q遍
    else: # use mask
        assert(L_Q == L_V) # requires that L_Q == L_V, i.e. for self-attention only
        contex = V.cumsum(dim=-2)
    return contex

我们将上述函数的实现展开一步一步推测其做的工作,输入的V与上一步的Q和K相同:

B, H, L_V, D = V.shape # 这里V的维度与Q和K相同
'''
V:(1, 2, 4, 6)
tensor([[[[ 1.,  2.,  3.,  4.,  5.,  6.],
          [ 7.,  8.,  9., 10., 11., 12.],
          [13., 14., 15., 16., 17., 18.],
          [19., 20., 21., 22., 23., 24.]],

         [[25., 26., 27., 28., 29., 30.],
          [31., 32., 33., 34., 35., 36.],
          [37., 38., 39., 40., 41., 42.],
          [43., 44., 45., 46., 47., 48.]]]])
'''
if not self.mask_flag: # 源码中mask_flag是False, 因此只会走下面的分支
    V_sum = V.mean(dim=-2) # 序列长度为4的value的均值
    '''
    V_sum:(1, 2, 6)
    tensor([[[10., 11., 12., 13., 14., 15.],
         [34., 35., 36., 37., 38., 39.]]])
    '''
    V_sum_unsequeese = V_sum.unsqueeze(-2)
    '''
    V_sum_unsequeese:(1, 2, 1, 6)
    tensor([[[[10., 11., 12., 13., 14., 15.]],

         [[34., 35., 36., 37., 38., 39.]]]])
    '''
    contex = V_sum_unsequeese.expand(B, H, L_Q, V_sum.shape[-1]).clone()
    '''
    contex:(1, 2, 4, 6)
    tensor([[[[10., 11., 12., 13., 14., 15.],
          [10., 11., 12., 13., 14., 15.],
          [10., 11., 12., 13., 14., 15.],
          [10., 11., 12., 13., 14., 15.]],

         [[34., 35., 36., 37., 38., 39.],
          [34., 35., 36., 37., 38., 39.],
          [34., 35., 36., 37., 38., 39.],
          [34., 35., 36., 37., 38., 39.]]]])
    '''
else: # use mask
    assert(L_Q == L_V) # requires that L_Q == L_V, i.e. for self-attention only
    contex = V.cumsum(dim=-2)

之后_update_context的工作就是根据论文中的公式计算ProbSparse self-attention,其中 Q ‾ \overline{Q} Q表示选择出的top-u个query所组成的新Q矩阵
A ( Q ; K ; V ) = S o f t m a x ( Q ‾ K T d ) V A(Q;K;V)=Softmax(\frac{\overline{Q}K^{T}}{\sqrt{d}})V A(Q;K;V)=Softmax(d QKT)V

# scores:[B, H, n_top, L_K], index:[B, H, n_top]
def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
    B, H, L_V, D = V.shape

    if self.mask_flag:
        attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
        scores.masked_fill_(attn_mask.mask, -np.inf)
	# attn的维度是[B, H, n_top, L_K]其中L_K = L_V
    attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores)
	# context_in维度是[B, H, L_Q, D]其中L_Q这个维度只有index指定的n_top个位置的值被更新了
    # 之所以这样做是因为在进行query和key的内积时是随机采样后再选择n_top个query与所有的key进行内积的
    # 所以最终在与value进行矩阵乘时也只有n_top个选中的query才有输出的更新,其余的在代码实现中用替代前的均值替代
    context_in[torch.arange(B)[:, None, None],	
               torch.arange(H)[None, :, None],
               index, :] = torch.matmul(attn, V).type_as(context_in)
    if self.output_attention:
        attns = (torch.ones([B, H, L_V, L_V])/L_V).type_as(attn).to(attn.device) 
        attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn
        return (context_in, attns)
    else:
        return (context_in, None)

我们将上面的函数展开以此来查看具体做了些什么操作:

B, H, L_V, D = V.shape # (1, 2, 4, 6)

if self.mask_flag:
	attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
    scores.masked_fill_(attn_mask.mask, -np.inf)

attn = torch.softmax(scores, dim=-1) #(1, 2, 2, 4) nn.Softmax(dim=-1)(scores)
'''
scores:(1, 2, 2, 4)
tensor([[[[ 140.0292,  367.8317,  595.6343,  823.4368],
          [ 191.4685,  507.4526,  823.4368, 1139.4210]],

         [[2667.9026, 3248.4319, 3828.9609, 4409.4897],
          [3072.0686, 3740.7793, 4409.4897, 5078.2007]]]])

attn:(1, 2, n_top:2, 4)
tensor([[[[0., 0., 0., 1.],
          [0., 0., 0., 1.]],

         [[0., 0., 0., 1.],
          [0., 0., 0., 1.]]]])
'''
attn_V = torch.matmul(attn, V).type_as(context_in)
'''
V:(1, 2, 4, 6)
tensor([[[[ 1.,  2.,  3.,  4.,  5.,  6.],
          [ 7.,  8.,  9., 10., 11., 12.],
          [13., 14., 15., 16., 17., 18.],
          [19., 20., 21., 22., 23., 24.]],

         [[25., 26., 27., 28., 29., 30.],
          [31., 32., 33., 34., 35., 36.],
          [37., 38., 39., 40., 41., 42.],
          [43., 44., 45., 46., 47., 48.]]]])

attn_V:(1, 2, 2, 6)
tensor([[[[19., 20., 21., 22., 23., 24.],
          [19., 20., 21., 22., 23., 24.]],

         [[43., 44., 45., 46., 47., 48.],
          [43., 44., 45., 46., 47., 48.]]]])
'''
context_in_B = torch.arange(B)[:, None, None] #(1, 1, 1)
context_in_H = torch.arange(H)[None, :, None] #(1, 2, 1)
context_in[context_in_B, context_in_H, index, :] = attn_V
'''
index:(1, 2, 2)
tensor([[[3, 2],
         [3, 2]]])
->index_t:(1, 2, 2, D)
tensor([[[[3., 3., 3., 3., 3., 3.],
          [2., 2., 2., 2., 2., 2.]],

         [[3., 3., 3., 3., 3., 3.],
          [2., 2., 2., 2., 2., 2.]]]])
          
 tensor([[[[10., 11., 12., 13., 14., 15.],
          [10., 11., 12., 13., 14., 15.],
          [19., 20., 21., 22., 23., 24.],
          [19., 20., 21., 22., 23., 24.]],

         [[34., 35., 36., 37., 38., 39.],
          [34., 35., 36., 37., 38., 39.],
          [43., 44., 45., 46., 47., 48.],
          [43., 44., 45., 46., 47., 48.]]]])

context_in old:(1, 2, 4, 6)
tensor([[[[10., 11., 12., 13., 14., 15.],
          [10., 11., 12., 13., 14., 15.],
          [10., 11., 12., 13., 14., 15.],
          [10., 11., 12., 13., 14., 15.]],

         [[34., 35., 36., 37., 38., 39.],
          [34., 35., 36., 37., 38., 39.],
          [34., 35., 36., 37., 38., 39.],
          [34., 35., 36., 37., 38., 39.]]]])
         

context_in new:(1, 2, 4, 6)
tensor([[[[10., 11., 12., 13., 14., 15.],
          [10., 11., 12., 13., 14., 15.],
          [19., 20., 21., 22., 23., 24.],
          [19., 20., 21., 22., 23., 24.]],

         [[34., 35., 36., 37., 38., 39.],
          [34., 35., 36., 37., 38., 39.],
          [43., 44., 45., 46., 47., 48.],
          [43., 44., 45., 46., 47., 48.]]]])
'''

if self.output_attention:
    attns = (torch.ones([B, H, L_V, L_V])/L_V).type_as(attn).to(attn.device)
    attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn
    return (context_in, attns)
else:
	return (context_in, None)

3.1.3 Encoder

class Encoder(nn.Module):
    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers) # EncoderLayer List
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        # x [B, L, D]
        attns = [] #记录每层attention的结果
        if self.conv_layers is not None:
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask) #x:[batch_size, seq_len, d_model]
                x = conv_layer(x) #进行Self-attention distilling来减小内存的占用 x:[batch_size, seq_len/2, d_model]第一次经过conv_layer时
                attns.append(attn)
            x, attn = self.attn_layers[-1](x, attn_mask=attn_mask)
            attns.append(attn)
        else:
            for attn_layer in self.attn_layers:
                # x:[batch_size, seq_len, d_model]
                x, attn = attn_layer(x, attn_mask=attn_mask)
                attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns

这里将attn_layers和conv_layers的内容也打印在下面:

#attn_layers 
[
     EncoderLayer(
         AttentionLayer(Attn(False, factor, attention_dropout=dropout, output_attention=output_attention), 
                        d_model, n_heads, mix=False),
         d_model,
         d_ff,
         dropout=dropout,
         activation=activation
     ) for l in range(e_layers)
  ] # e_layer:Num of encoder layers (defaults to 2)
#conv_layers 
[
    ConvLayer(
        d_model
    ) for l in range(e_layers-1)
] if distil else None
3.1.3.1 EncoderLayer

下面详细展示EncoderLayer和ConvLayer的详细实现:

class EncoderLayer(nn.Module):
    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4*d_model
        self.attention = attention #AttentionLayer
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) #d_ff依赖self.args.d_ff含义是Dimension of fcn (defaults to 2048)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model) #LayerNorm取同一样本的不同通道进行归一化
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu
	
    # attention layer->skip layer操作->LayerNorm->MLP(conv1d)->skip layer操作->LayerNorm
    def forward(self, x, attn_mask=None):
        # x [B, L, D]
        # x = x + self.dropout(self.attention(
        #     x, x, x,
        #     attn_mask = attn_mask
        # ))
        new_x, attn = self.attention(
            x, x, x,
            attn_mask = attn_mask
        ) #new_x:[batch_size, seq_len, d_model]
        # 这里有一个skip layer的操作
        x = x + self.dropout(new_x)

        y = x = self.norm1(x)
        # 两层的1维卷积操作
        y = self.dropout(self.activation(self.conv1(y.transpose(-1,1))))
        y = self.dropout(self.conv2(y).transpose(-1,1))
        
        return self.norm2(x+y), attn
3.1.3.2 ConvLayer

ConvLayer的结构就是在计算论文中的distill的公式,其中 [ . ] A B [.]_{AB} [.]AB表示attention block:
X j + 1 t = M a x P o o l ( E L U ( C o n v 1 d ( [ X j t ] A B ) ) ) X_{j+1}^{t}=MaxPool(ELU(Conv1d([X_j^t]_{AB}))) Xj+1t=MaxPool(ELU(Conv1d([Xjt]AB)))

class ConvLayer(nn.Module):
    # c_in的维度应该与d_model=512相同
    def __init__(self, c_in):
        super(ConvLayer, self).__init__()
        padding = 1 if torch.__version__>='1.5.0' else 2
        self.downConv = nn.Conv1d(in_channels=c_in,
                                  out_channels=c_in,
                                  kernel_size=3,
                                  padding=padding,
                                  padding_mode='circular')
        self.norm = nn.BatchNorm1d(c_in)
        self.activation = nn.ELU()
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        # x:[batch_size, seq_len, d_model]
        x = self.downConv(x.permute(0, 2, 1))
        x = self.norm(x)
        x = self.activation(x)
        x = self.maxPool(x) #经过maxPool操作后,x:[batch_size, d_model, seq_len/2]
        x = x.transpose(1,2) #第一次经过conv_layer时,返回结果的维度是[batch_size, seq_len/2, d_model]
        return x

3.1.4 Decoder

class Decoder(nn.Module):
    def __init__(self, layers, norm_layer=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers) #DecoderLayer List
        self.norm = norm_layer #LayerNorm

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        for layer in self.layers:
            x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)

        if self.norm is not None:
            x = self.norm(x)

        return x

我们把decoder在Informer中的初始化放在下方,可以看到Decoder中每一层Decoder Layer中包含两个Attention Layer:

self.decoder = Decoder(
    [
        DecoderLayer(
            AttentionLayer(Attn(True, factor, attention_dropout=dropout, output_attention=False), 
                        d_model, n_heads, mix=mix)#self_attention,
            AttentionLayer(FullAttention(False, factor, attention_dropout=dropout, output_attention=False), 
                        d_model, n_heads, mix=False)#cross_attention,
            d_model,
            d_ff,
            dropout=dropout,
            activation=activation,
        )
        for l in range(d_layers)
    ],
    norm_layer=torch.nn.LayerNorm(d_model)
)
3.1.4.1 DecoderLayer
class DecoderLayer(nn.Module):
    def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
                 dropout=0.1, activation="relu"):
        super(DecoderLayer, self).__init__()
        d_ff = d_ff or 4*d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu #激活函数默认是gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        # ProbSparse Attention->DropOut->LayerNorm->Full Attention->Dropout->LayerNorm->MLP->LayerNorm
        x = x + self.dropout(self.self_attention(
            x, x, x,
            attn_mask=x_mask
        )[0])
        x = self.norm1(x)

        x = x + self.dropout(self.cross_attention(
            x, cross, cross,
            attn_mask=cross_mask
        )[0])

        y = x = self.norm2(x)
        y = self.dropout(self.activation(self.conv1(y.transpose(-1,1))))
        y = self.dropout(self.conv2(y).transpose(-1,1))

        return self.norm3(x+y)

4 模型部署落地

落地方案采用转换为MNN框架兼容格式onnx,之后通过MNN框架进行落地部署

4.1 将pytorch代码转换为onnx格式

将model转换为onnx进行持久化用到pytorch中的torch.onnx.export接口(官方网址:https://pytorch.org/docs/master/onnx.html?highlight=torch%20onnx%20export#torch.onnx.export),先看一个网上的例子:

import torch
import torch.nn as nn
import onnx
import numpy as np
class Model(nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.conv1=nn.Conv2d(3,3, kernel_size=3, stride=2,padding=1)
    def forward(self,x,y):
        result1=self.conv1(x)
        result2=self.conv1(y)
        return result1,result2

model=Model()
model.eval() # 若存在batchnorm、dropout层则一定要eval()!!!!再export

input_names = ["input_0","input_1"]
output_names = ["output_0","output_1"]

x=torch.randn((1,3,12,12))
y=torch.randn((1,3,6,6))

torch.onnx.export(model,(x,y),'model.onnx',input_names=input_names,output_names=output_names,
  dynamic_axes={'input_0':[0],'output_0':[0]}) # 指定input_0和output_0的batch可变

需要先确定我们Informer模型的输入与输出的维度从而确定接口第二项参数的维度。

其中输入包括batch_x, batch_x_mark, dec_inp, batch_y_mark,输出包括outputs_app, outputs_user,其中在APP预测项目中各参数维度分别为:

  • batch_x:(32,12,64)
  • batch_x_mark:(32,12,5)
  • dec_inp:(32,7,64)
  • batch_y_mark:(32,7,5)
  • outputs_app:(32,1,1521)
  • outputs_user:(32,1,851)
评论 19
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值