临时代码,随便写写

该代码实现了一个时间序列预测模型,主要基于GPT-2模型,并结合了趋势(trend)、季节性(season)、噪声(noise)等时间序列分解的概念。下面是详细解释:

  1. 模块结构

    • ComplexLinear: 这是一个自定义的线性层,用于处理复数输入。该层将输入的实部和虚部分开处理,并将其组合为复数输出。
    • MultiFourier: 用于执行多项傅里叶变换,目的是捕捉时间序列中的周期性成分。
    • moving_avg: 这是一个移动平均模块,用于平滑时间序列数据,从而突出趋势成分。
    • TEMPO: 主模型类,它整合了GPT-2预训练模型和时间序列分解策略,处理趋势、季节性和噪声。
  2. 时间序列分解
    模型将输入的时间序列分为三个部分:趋势、季节性和噪声,并分别进行建模。首先,使用移动平均方法提取趋势,然后通过减去趋势提取季节性成分,最后将剩余的部分视为噪声。

  3. GPT-2 预训练模型
    模型使用了GPT-2预训练模型来建模每个分量(趋势、季节性和噪声)。如果使用了prompt机制,模型将额外添加一些输入提示(prompt),例如关于趋势和季节性的描述文本,并将这些提示作为额外输入的一部分。

  4. 多头提示选择
    如果启用了“pool”机制,模型可以从预定义的提示池中选择最相关的提示,通过相似性度量找到与当前输入最匹配的提示。这部分通过select_prompt函数完成。

  5. 模型的输入和输出
    输入是时间序列数据,模型首先通过移动平均提取趋势、季节性和噪声,然后将这些分量分别输入到对应的GPT-2模型中进行处理。最终将三个分量的输出相加,得到最后的预测结果。

  6. 损失计算
    模型不仅关注最终的预测结果,还会计算趋势、季节性和噪声的局部损失(local loss),通过最小化这些损失来进一步提高模型的预测性能。

  7. 训练参数设置
    通过print_trainable_parameters函数,可以显示模型中可训练参数的数量,这对于调试和优化模型非常有帮助。

总的来说,这个代码实现了一个复杂的时间序列预测模型,结合了预训练语言模型(GPT-2)和时间序列的趋势、季节性、噪声分解,能够用于捕捉时间序列中的不同成分,并进行准确的预测。

1. 导入库

import numpy as np
import torch
import torch.nn as nn
from torch import optim

from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from transformers import BertTokenizer, BertModel
from einops import rearrange
from embed import DataEmbedding, DataEmbedding_wo_time
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from utils.rev_in import RevIn
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
  • numpy: 用于处理数值计算。
  • torch: PyTorch深度学习框架。
  • torch.nn: 包含构建神经网络的基础模块。
  • transformers: 用于加载预训练模型(如GPT-2、BERT等)的库。
  • einops: 用于处理和变换张量的库,简化张量操作。
  • embed: 用户定义的模块,可能包括自定义的嵌入层。
  • utils.rev_in: 自定义模块,包含RevIn类,可能用于某种特殊的输入处理。
  • peft: 用于处理参数高效微调(PEFT)的库。

2. 定义损失函数

criterion = nn.MSELoss()
  • MSELoss: 定义均方误差损失,用于回归任务中衡量预测值和真实值之间的差距。

3. 定义复数线性层

class ComplexLinear(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(ComplexLinear, self).__init__()
        self.fc_real = nn.Linear(input_dim, output_dim)
        self.fc_imag = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x_real = torch.real(x)
        x_imag = torch.imag(x)
        out_real = self.fc_real(x_real) - self.fc_imag(x_imag)
        out_imag = self.fc_real(x_imag) + self.fc_imag(x_real)
        return torch.complex(out_real, out_imag)
  • ComplexLinear: 自定义线性层,处理复数输入。它使用两个线性层分别处理实部和虚部,并组合输出。
  • forward 方法: 将输入分为实部和虚部,经过线性层后重组为复数输出。

4. 打印可训练参数

def print_trainable_parameters(model):
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}"
    )
  • print_trainable_parameters: 用于打印模型中可训练参数的数量和比例,帮助分析模型的复杂度和训练需求。

5. 定义多项傅里叶变换类

class MultiFourier(torch.nn.Module):
    def __init__(self, N, P):
        super(MultiFourier, self).__init__()
        self.N = N
        self.P = P
        self.a = torch.nn.Parameter(torch.randn(max(N), len(N)), requires_grad=True)
        self.b = torch.nn.Parameter(torch.randn(max(N), len(N)), requires_grad=True)

    def forward(self, t):
        output = torch.zeros_like(t)
        t = t.unsqueeze(-1).repeat(1, 1, max(self.N))  # shape: [batch_size, seq_len, max(N)]
        n = torch.arange(max(self.N)).unsqueeze(0).unsqueeze(0).to(t.device)  # shape: [1, 1, max(N)]
        for j in range(len(self.N)):  # loop over seasonal components
            cos_terms = torch.cos(2 * np.pi * (n[..., :self.N[j]] + 1) * t[..., :self.N[j]] / self.P[j])  # shape: [batch_size, seq_len, N[j]]
            sin_terms = torch.sin(2 * np.pi * (n[..., :self.N[j]] + 1) * t[..., :self.N[j]] / self.P[j])  # shape: [batch_size, seq_len, N[j]]
            output += torch.matmul(cos_terms, self.a[:self.N[j], j]) + torch.matmul(sin_terms, self.b[:self.N[j], j])
        return output
  • MultiFourier: 用于实现多项傅里叶变换,通过傅里叶级数对输入数据进行季节性建模。
  • forward 方法: 计算傅里叶变换的余弦和正弦分量,并将其与参数ab进行矩阵乘法,返回变换后的输出。

6. 定义移动平均类

class moving_avg(nn.Module):
    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        x = torch.cat([front, x, end], dim=1)
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x
  • moving_avg: 用于实现移动平均,突出时间序列的趋势部分。
  • forward 方法: 通过在输入数据的前后添加填充,计算滑动平均,并返回平滑后的结果。

7. 定义TEMPO模型类

class TEMPO(nn.Module):
    
    def __init__(self, configs, device):
        super(TEMPO, self).__init__()
        self.prompt = getattr(configs, 'prompt', 0)
        self.is_gpt = configs.is_gpt
        self.patch_size = configs.patch_size
        self.pretrain = configs.pretrain
        self.stride = configs.stride
        self.patch_num = (configs.seq_len - self.patch_size) // self.stride + 1
        self.mul_season = MultiFourier([2], [24*4]) 

        self.padding_patch_layer = nn.ReplicationPad1d((0, self.stride)) 
        self.patch_num += 1

        self.map_trend = nn.Linear(configs.seq_len, configs.seq_len)
        self.map_season  = nn.Sequential(
            nn.Linear(configs.seq_len, 4*configs.seq_len),
            nn.ReLU(),
            nn.Linear(4*configs.seq_len, configs.seq_len)
        )

        self.map_resid = nn.Linear(configs.seq_len, configs.seq_len)

        kernel_size = 25
        self.moving_avg = moving_avg(kernel_size, stride=1)

        if configs.pretrain:
            local_model_path = './gpt2/' 
            self.gpt2_trend = GPT2Model.from_pretrained(local_model_path, output_attentions=True,
                                                        output_hidden_states=True)
        else:
            print("------------------no pretrain------------------")
            self.gpt2_trend = GPT2Model(GPT2Config())
            self.gpt2_season = GPT2Model(GPT2Config())
            self.gpt2_noise = GPT2Model(GPT2Config())

            self.gpt2_trend.h = self.gpt2_trend.h[:configs.gpt_layers]
           
            self.prompt = configs.prompt
            self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
            self.gpt2_trend_token = self.tokenizer(text="Predict the future time step given the trend", return_tensors="pt").to(device)
            self.gpt2_season_token = self.tokenizer(text="Predict the future time step given the season", return_tensors="pt").to(device)
            self.gpt2_residual_token = self.tokenizer(text="Predict the future time step given the residual", return_tensors="pt").to(device)

            self.token_len = len(self.gpt2_trend_token['input_ids'][0])

            try:
                self.pool = configs.pool
                if self.pool:
                    self.prompt_record_plot = {}
                    self.prompt_record_id = 0
                    self.diversify = True
            except:
                self.pool = False

            if self.pool:
                self.prompt_key_dict = nn.ParameterDict({})
                self.prompt_value_dict = nn.ParameterDict({})
                self.summary_map = nn.Linear(self.patch_num, 1)
                self.pool_size = 30
                self.top_k = 3
                self.prompt_len = 3
                self.token_len = self.prompt_len * self.top

_k
                for i in range(self.pool_size):
                    prompt_shape = (self.prompt_len, 768)
                    key_shape = (768)
                    self.prompt_value_dict[f"prompt_value_{i}"] = nn.Parameter(torch.randn(prompt_shape))
                    self.prompt_key_dict[f"prompt_key_{i}"] = nn.Parameter(torch.randn(key_shape))
            
                self.prompt_record = {f"id_{i}": 0 for i in range(self.pool_size)}
                self.prompt_record_trend = {}
                self.prompt_record_season = {}
                self.prompt_record_residual = {}
                self.diversify = True

        self.in_layer_trend = nn.Linear(configs.patch_size, configs.d_model)
        self.in_layer_season = nn.Linear(configs.patch_size, configs.d_model)
        self.in_layer_noise = nn.Linear(configs.patch_size, configs.d_model)

        if configs.prompt == 1:
            self.use_token = configs.use_token
            if self.use_token == 1:
                    self.out_layer_trend = nn.Linear(configs.d_model * (self.patch_num+self.token_len), configs.pred_len)
                    self.out_layer_season = nn.Linear(configs.d_model * (self.patch_num+self.token_len), configs.pred_len)
                    self.out_layer_noise = nn.Linear(configs.d_model * (self.patch_num+self.token_len), configs.pred_len)
            else:
                self.out_layer_trend = nn.Linear(configs.d_model * self.patch_num, configs.pred_len)
                self.out_layer_season = nn.Linear(configs.d_model * self.patch_num, configs.pred_len)
                self.out_layer_noise = nn.Linear(configs.d_model * self.patch_num, configs.pred_len)

            self.prompt_layer_trend = nn.Linear(configs.d_model, configs.d_model)
            self.prompt_layer_season = nn.Linear(configs.d_model, configs.d_model)
            self.prompt_layer_noise = nn.Linear(configs.d_model, configs.d_model)

            for layer in (self.prompt_layer_trend, self.prompt_layer_season, self.prompt_layer_noise):
                layer.to(device=device)
                layer.train()
        else:
            self.out_layer_trend = nn.Linear(configs.d_model * self.patch_num, configs.pred_len)
            self.out_layer_season = nn.Linear(configs.d_model * self.patch_num, configs.pred_len)
            self.out_layer_noise = nn.Linear(configs.d_model * self.patch_num, configs.pred_len)

        if configs.freeze and configs.pretrain:
            for i, (name, param) in enumerate(self.gpt2_trend.named_parameters()):
                if 'ln' in name or 'wpe' in name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False

        config = LoraConfig(
            r=16,
            lora_alpha=16,
            lora_dropout=0.1,
            bias="lora_only",
        )
         
        self.gpt2_trend = get_peft_model(self.gpt2_trend, config)
        print_trainable_parameters(self.gpt2_trend)

        for layer in (self.gpt2_trend, self.in_layer_trend, self.out_layer_trend, 
                      self.in_layer_season, self.out_layer_season, self.in_layer_noise, self.out_layer_noise):
            layer.to(device=device)
            layer.train()

        for layer in (self.map_trend, self.map_season, self.map_resid):
            layer.to(device=device)
            layer.train()
        
        self.cnt = 0

        self.num_nodes = configs.num_nodes
        self.rev_in_trend = RevIn(num_features=self.num_nodes).to(device)
        self.rev_in_season = RevIn(num_features=self.num_nodes).to(device)
        self.rev_in_noise = RevIn(num_features=self.num_nodes).to(device)
  • TEMPO: 这是主模型类,整合了时间序列处理、GPT-2模型以及其他辅助模块。
  • __init__ 方法:
    • configs: 接收配置参数,设置模型的各个超参数。
    • patch_num: 根据输入序列的长度和步幅计算补丁数量,便于后续处理。
    • 多项傅里叶: 初始化MultiFourier类,用于捕捉季节性成分。
    • 移动平均: 初始化moving_avg类,用于平滑处理。
    • GPT-2模型: 根据是否预训练加载相应的GPT-2模型,或者初始化新的模型。
    • Tokenization: 使用GPT-2的分词器生成提示token。
    • Pool和Prompt: 根据配置初始化prompt相关的参数,如果启用了pool机制,建立用于选择的参数字典。
    • 输入和输出层: 定义输入层和输出层,分别用于处理趋势、季节性和噪声的输入。

8. 存储张量的函数

def store_tensors_in_dict(self, original_x, original_trend, original_season, original_noise, trend_prompts, season_prompts, noise_prompts):
    self.prompt_record_id += 1 
    for i in range(original_x.size(0)):
        self.prompt_record_plot[self.prompt_record_id + i] = {
            'original_x': original_x[i].tolist(),
            'original_trend': original_trend[i].tolist(),
            'original_season': original_season[i].tolist(),
            'original_noise': original_noise[i].tolist(),
            'trend_prompt': trend_prompts[i],
            'season_prompt': season_prompts[i],
            'noise_prompt': noise_prompts[i],
        }
  • store_tensors_in_dict: 用于将原始输入和处理后的各个成分(趋势、季节性、噪声)以及对应的提示存储在字典中,以便后续分析和可视化。

9. L2归一化

def l2_normalize(self, x, dim=None, epsilon=1e-12):
    square_sum = torch.sum(x ** 2, dim=dim, keepdim=True)
    x_inv_norm = torch.rsqrt(torch.maximum(square_sum, torch.tensor(epsilon, device=x.device)))
    return x * x_inv_norm
  • l2_normalize: 实现L2归一化,用于标准化输入张量的大小,防止过大的数值影响模型的学习过程。

10. 选择提示

def select_prompt(self, summary, prompt_mask=None):
    prompt_key_matrix = torch.stack(tuple([self.prompt_key_dict[i] for i in self.prompt_key_dict.keys()]))
    prompt_norm = self.l2_normalize(prompt_key_matrix, dim=1)
    summary_reshaped = summary.view(-1, self.patch_num)
    summary_mapped = self.summary_map(summary_reshaped)
    summary = summary_mapped.view(-1, 768)
    summary_embed_norm = self.l2_normalize(summary, dim=1)
    similarity = torch.matmul(summary_embed_norm, prompt_norm.t())
    if not prompt_mask==None:
        idx = prompt_mask
    else:
        topk_sim, idx = torch.topk(similarity, k=self.top_k, dim=1)
    if prompt_mask==None:
        count_of_keys = torch.bincount(torch.flatten(idx), minlength=15)
        for i in range(len(count_of_keys)):
            self.prompt_record[f"id_{i}"] += count_of_keys[i].item()

    prompt_value_matrix = torch.stack(tuple([self.prompt_value_dict[i] for i in self.prompt_value_dict.keys()]))
    batched_prompt_raw = prompt_value_matrix[idx].squeeze(1)
    batch_size, top_k, length, c = batched_prompt_raw.shape
    batched_prompt = batched_prompt_raw.reshape(batch_size, top_k * length, c) 
   
    batched_key_norm = prompt_norm[idx]
    summary_embed_norm = summary_embed_norm.unsqueeze(1)
    sim = batched_key_norm * summary_embed_norm
    reduce_sim = torch.sum(sim) / summary.shape[0]

    selected_prompts = [tuple(sorted(row)) for row in idx.tolist()]

    return batched_prompt, reduce_sim, selected_prompts
  • select_prompt: 从可用的提示中选择最相关的提示。根据输入的摘要与提示的相似度计算,返回所选提示及其相似度值。
  • 计算提示的相似度并选择最相关的提示,若有掩码则根据掩码选择。

11. 标准化与补丁获取

def get_norm(self, x, d = 'norm'):
    means = x.mean(1, keepdim=True).detach()
    x = x - means
    stdev = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5).detach() 
    x /= stdev

    return x, means, stdev

def get_patch(self, x):
    x = rearrange(x, 'b l m -> b m l')
    x = self.padding_patch_layer(x)
    x = x.unfold(dimension=-1, size=self.patch_size, step=self.stride)
    x = rearrange(x, 'b m n p -> (b m) n p')

    return x
  • get_norm: 计算输入的均值和标准差,用于标准化处理,使得每个输入在训练时均值为0,方差为1。
  • get_patch: 通过重组

输入张量并应用填充和滑动窗口,获取补丁数据,以便后续处理。

12. 获取嵌入

def get_emb(self, x, tokens=None, type='Trend'):
    if tokens is None:
        if type == 'Trend':
            x = self.gpt2_trend(inputs_embeds=x).last_hidden_state
        elif type == 'Season':
            x = self.gpt2_trend(inputs_embeds=x).last_hidden_state
        elif type == 'Residual':
            x = self.gpt2_trend(inputs_embeds=x).last_hidden_state
        return x
    else:
        [a, b, c] = x.shape
        if type == 'Trend': 
            if self.pool:
                prompt_x, reduce_sim, selected_prompts_trend = self.select_prompt(x, prompt_mask=None)
                for selected_prompt_trend in selected_prompts_trend:
                    self.prompt_record_trend[selected_prompt_trend] = self.prompt_record_trend.get(selected_prompt_trend, 0) + 1
                selected_prompts = selected_prompts_trend
            else:
                prompt_x = self.gpt2_trend.wte(tokens)
                prompt_x = prompt_x.repeat(a, 1, 1)
                prompt_x = self.prompt_layer_trend(prompt_x)
            x = torch.cat((prompt_x, x), dim=1)

        elif type == 'Season':
            if self.pool:
                prompt_x, reduce_sim, selected_prompts_season = self.select_prompt(x, prompt_mask=None)
                for selected_prompt_season in selected_prompts_season:
                    self.prompt_record_season[selected_prompt_season] = self.prompt_record_season.get(selected_prompt_season, 0) + 1
                selected_prompts = selected_prompts_season
            else:
                prompt_x = self.gpt2_trend.wte(tokens)
                prompt_x = prompt_x.repeat(a, 1, 1)
                prompt_x = self.prompt_layer_season(prompt_x)
            x = torch.cat((prompt_x, x), dim=1)

        elif type == 'Residual':
            if self.pool:
                prompt_x, reduce_sim, selected_prompts_resid = self.select_prompt(x, prompt_mask=None)
                for selected_prompt_resid in selected_prompts_resid:
                    self.prompt_record_residual[selected_prompt_resid] = self.prompt_record_residual.get(selected_prompt_resid, 0) + 1
                selected_prompts = selected_prompts_resid
            else:
                prompt_x = self.gpt2_trend.wte(tokens)
                prompt_x = prompt_x.repeat(a, 1, 1)
                prompt_x = self.prompt_layer_noise(prompt_x)
            x = torch.cat((prompt_x, x), dim=1)

    if self.pool:
        return x, reduce_sim, selected_prompts
    else:
        return x
  • get_emb: 获取嵌入表示。根据输入类型(趋势、季节性、残差)处理输入数据,可能会添加提示token的嵌入表示。
  • 对于每种类型,根据是否启用了pool机制选择处理逻辑,将提示token嵌入与原输入拼接。

13. 模型前向传播

def forward(self, x, itr, trend, season, noise, test=False):
    B, L, M = x.shape # 4, 512, 1

    x = self.rev_in_trend(x, 'norm')

    original_x = x
    
    trend_local = self.moving_avg(x)
    trend_local = self.map_trend(trend_local.squeeze()).unsqueeze(2)
    season_local = x - trend_local
    season_local = self.map_season(season_local.squeeze().unsqueeze(1)).squeeze(1).unsqueeze(2)
    noise_local = x - trend_local - season_local

    trend, means_trend, stdev_trend = self.get_norm(trend)
    season, means_season, stdev_season = self.get_norm(season)
    noise, means_noise, stdev_noise = self.get_norm(noise)

    if trend is not None:
        trend_local_l = criterion(trend, trend_local)
        season_local_l = criterion(season, season_local)
        noise_local_l = criterion(noise, noise_local)
        
        loss_local = trend_local_l + season_local_l + noise_local_l 
        if test:
            print("trend local loss:", torch.mean(trend_local_l))
            print("Season local loss", torch.mean(season_local_l))
            print("noise local loss", torch.mean(noise_local_l))

    trend = self.get_patch(trend_local)
    season = self.get_patch(season_local)
    noise = self.get_patch(noise_local)

    trend = self.in_layer_trend(trend) # 4, 64, 768
    if self.is_gpt and self.prompt == 1:
        if self.pool:
            trend, reduce_sim_trend, trend_selected_prompts = self.get_emb(trend, self.gpt2_trend_token['input_ids'], 'Trend')
        else:
            trend = self.get_emb(trend, self.gpt2_trend_token['input_ids'], 'Trend')
    else:
        trend = self.get_emb(trend)

    season = self.in_layer_season(season) # 4, 64, 768
    if self.is_gpt and self.prompt == 1:
        if self.pool:
            season, reduce_sim_season, season_selected_prompts = self.get_emb(season, self.gpt2_season_token['input_ids'], 'Season')
        else:
            season = self.get_emb(season, self.gpt2_season_token['input_ids'], 'Season')
    else:
        season = self.get_emb(season)

    noise = self.in_layer_noise(noise)
    if self.is_gpt and self.prompt == 1:
        if self.pool:
            noise, reduce_sim_noise, noise_selected_prompts = self.get_emb(noise, self.gpt2_residual_token['input_ids'], 'Residual')
        else:
            noise = self.get_emb(noise, self.gpt2_residual_token['input_ids'], 'Residual')
    else:
        noise = self.get_emb(noise)

    x_all = torch.cat((trend, season, noise), dim=1)

    x = self.gpt2_trend(inputs_embeds=x_all).last_hidden_state 
    
    if self.prompt == 1:
        trend  = x[:, :self.token_len+self.patch_num, :]  
        season  = x[:, self.token_len+self.patch_num:2*self.token_len+2*self.patch_num, :]  
        noise = x[:, 2*self.token_len+2*self.patch_num:, :]
        if self.use_token == 0:
            trend = trend[:, self.token_len:, :]
            season = season[:, self.token_len:, :]
            noise = noise[:, self.token_len:, :]    
    else:
        trend  = x[:, :self.patch_num, :]  
        season  = x[:, self.patch_num:2*self.patch_num, :]  
        noise = x[:, 2*self.patch_num:, :] 
        
    trend = self.out_layer_trend(trend.reshape(B*M, -1)) # 4, 96
    trend = rearrange(trend, '(b m) l -> b l m', b=B) # 4, 96, 1
    
    season = self.out_layer_season(season.reshape(B*M, -1)) # 4, 96
    season = rearrange(season, '(b m) l -> b l m', b=B) # 4, 96, 1

    noise = self.out_layer_noise(noise.reshape(B*M, -1)) # 4, 96
    noise = rearrange(noise, '(b m) l -> b l m', b=B)
    
    outputs = trend + season + noise 

    outputs = self.rev_in_trend(outputs, 'denorm')
    return outputs, loss_local
  • forward: 实现模型的前向传播逻辑,处理输入数据并计算输出。
    • 数据预处理: 首先对输入进行反归一化(rev_in_trend),计算趋势、季节性和噪声的局部值。
    • 损失计算: 使用均方误差计算趋势、季节性和噪声的损失,并在测试模式下打印损失。
    • 补丁获取: 将趋势、季节性和噪声的局部值转换为补丁形式。
    • 嵌入计算: 使用不同的嵌入层获取趋势、季节性和噪声的表示,可能根据是否启用提示进行不同的处理。
    • 合并输出: 将处理后的趋势、季节性和噪声的输出合并,得到最终的预测输出。

总结

该代码实现了一个复杂的时间序列预测模型,通过集成GPT-2和多项傅里叶变换技术,对时间序列的趋势、季节性和噪声进行建模。通过使用各种自定义模块和方法,该模型能够有效处理和预测时间序列数据,具有良好的灵活性和扩展性。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值