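This code implements a time-series forecasting model built on GPT-2. It combines GPT-2 with the time-series decomposition concepts of trend, season (seasonality) and noise (residual). A detailed explanation follows.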
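Module structure:
- ComplexLinear: a custom linear layer for complex-valued input. It processes the real and imaginary parts with two real-valued linear layers and combines the results into a complex output.
- MultiFourier: a learnable multi-period Fourier series (sums of sine and cosine terms with trainable coefficients), intended to capture periodic components of a time series.
- moving_avg: a moving-average module that smooths the series to bring out its trend component.
- TEMPO: the main model class. It combines a pretrained GPT-2 with the decomposition strategy and handles the trend, season and noise components.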
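Time-series decomposition:
The model splits the input series into three components (trend, season and noise) and models each one. It proceeds in three steps:
1. It extracts the trend with a moving average and refines it with a learned linear map.
2. It derives the seasonal component from the detrended series (input minus trend), which it passes through a small MLP.
3. It treats whatever remains as noise.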
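The subtract-the-trend idea can be illustrated with a short numpy sketch. TEMPO refines the moving-average trend with map_trend and learns the season with map_season. This sketch substitutes classical additive decomposition instead: a moving-average trend plus per-phase seasonal means. The kernel size of 25 matches the code. The period of 24 and the synthetic series are assumptions made for illustration.

```python
import numpy as np

def decompose(x, kernel_size, period):
    """Classical additive decomposition of a 1-D series: x = trend + season + noise."""
    pad = (kernel_size - 1) // 2
    # Edge-pad by repeating the first/last value, then take a centered moving average
    padded = np.concatenate([np.full(pad, x[0]), x, np.full(pad, x[-1])])
    trend = np.convolve(padded, np.ones(kernel_size) / kernel_size, mode='valid')
    detrended = x - trend
    # Season: mean of the detrended values at each phase of the period, centered at zero
    phase_means = np.array([detrended[p::period].mean() for p in range(period)])
    season = np.resize(phase_means - phase_means.mean(), len(x))
    # Noise: whatever the trend and season do not explain
    noise = x - trend - season
    return trend, season, noise

t = np.arange(96 * 4, dtype=float)
rng = np.random.default_rng(0)
x = 0.01 * t + np.sin(2 * np.pi * t / 24) + 0.1 * rng.normal(size=t.size)

trend, season, noise = decompose(x, kernel_size=25, period=24)
print(np.allclose(trend + season + noise, x))  # True
```

By construction, the three parts always add back up to the input. What differs between methods is how the trend and season are estimated.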
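GPT-2 pretrained model:
A GPT-2 backbone (gpt2_trend) models all three components. In the forward pass, the trend, season and noise embeddings are concatenated along the sequence dimension and passed through this single shared model.

When the prompt mechanism is enabled, the model prepends prompt embeddings to each component, so that the prompt becomes part of the input. One example is the embedding of a text prompt such as "Predict the future time step given the trend".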
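Multi-prompt selection (prompt pool):
When the "pool" mechanism is enabled, the model selects the most relevant prompts from a learnable prompt pool. It measures the similarity between a summary of the current input and each prompt key, and keeps the best matches. This is done by the select_prompt function.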
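Model input and output:
The input is a time series, which goes through the following steps:
1. The model normalizes it and extracts the trend, season and noise components.
2. It splits each component into patches and projects them with a component-specific input layer.
3. It concatenates the three embedded sequences (with their prompts, if any) and passes them through the shared GPT-2.
4. It splits the output back into three parts and maps each to the forecast horizon with its own output layer.
5. It sums the three forecasts and denormalizes the result to give the final prediction.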
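The split in forward relies on each component occupying token_len + patch_num consecutive positions of the concatenated sequence. The numpy sketch below checks that bookkeeping. It uses hypothetical sizes (d_model shrunk to 8) and an identity map in place of GPT-2, so it illustrates the indexing only, not the model.

```python
import numpy as np

# Hypothetical sizes: 4 series, token_len = 9 prompt tokens, patch_num = 64, d_model = 8
n, token_len, patch_num, d = 4, 9, 64, 8
seg = token_len + patch_num
trend = np.full((n, seg, d), 1.0)
season = np.full((n, seg, d), 2.0)
noise = np.full((n, seg, d), 3.0)

# One sequence through the shared backbone (identity stands in for GPT-2 here)
x = np.concatenate((trend, season, noise), axis=1)    # [n, 3 * seg, d]

# Split back the same way forward does when prompt == 1
trend_out = x[:, :seg, :]
season_out = x[:, seg:2 * seg, :]
noise_out = x[:, 2 * seg:, :]
# use_token == 0: drop the prompt positions before the output heads
trend_out = trend_out[:, token_len:, :]
print(trend_out.shape)                                # (4, 64, 8)
print(np.all(season_out == 2.0), np.all(noise_out == 3.0))  # True True
```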
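Loss computation:
Besides the final forecast, the model computes local losses for the trend, season and noise components. Each local loss is the MSE between a learned component and the corresponding trend, season or noise tensor passed into forward. forward returns their sum as loss_local, which the training loop can add to the forecasting loss. This encourages the learned decomposition to match the reference decomposition.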
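Trainable parameters:
The print_trainable_parameters function reports the number and share of trainable parameters in the model. This is useful when debugging and tuning the model, for example to check the effect of freezing GPT-2 and adding LoRA adapters.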
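Overall, the code implements a time-series forecasting model that combines a pretrained language model (GPT-2) with a trend/season/noise decomposition. The different components of a series are captured separately and then combined into one forecast.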
1. Import libraries
```python
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from transformers import BertTokenizer, BertModel
from einops import rearrange
from embed import DataEmbedding, DataEmbedding_wo_time
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from utils.rev_in import RevIn
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
```
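- numpy: numerical computation.
- torch: the PyTorch deep-learning framework.
- torch.nn: the basic building blocks for neural networks.
- transformers: library for loading pretrained models such as GPT-2 and BERT. BertTokenizer and BertModel are imported but not used in the code shown.
- einops: library that simplifies tensor reshaping and rearrangement.
- embed: a user-defined module, probably containing custom embedding layers.
- utils.rev_in: a user-defined module providing the RevIn class, presumably reversible instance normalization. forward calls it with 'norm' on the input and 'denorm' on the output.
- peft: library for parameter-efficient fine-tuning (PEFT), used here to add LoRA adapters to GPT-2.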
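2. Define the loss function
```python
criterion = nn.MSELoss()
```
- MSELoss: mean squared error loss, which measures the gap between predictions and targets in regression tasks. Here it computes the local decomposition losses in forward.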
3. Define the complex linear layer
```python
class ComplexLinear(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(ComplexLinear, self).__init__()
        # Two real-valued layers hold the real (Wr) and imaginary (Wi) parts of the weights
        self.fc_real = nn.Linear(input_dim, output_dim)
        self.fc_imag = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x_real = torch.real(x)
        x_imag = torch.imag(x)
        # (Wr + iWi)(xr + i*xi) = (Wr xr - Wi xi) + i(Wr xi + Wi xr)
        out_real = self.fc_real(x_real) - self.fc_imag(x_imag)
        out_imag = self.fc_real(x_imag) + self.fc_imag(x_real)
        return torch.complex(out_real, out_imag)
```
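- ComplexLinear: a custom linear layer for complex-valued input. It holds two real-valued linear layers, one for the real part and one for the imaginary part of the weights, and combines their outputs.
- forward method: splits the input into real and imaginary parts and applies the two layers according to the rule for complex multiplication. It then reassembles the result as a complex tensor.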
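To see why this real/imaginary combination is correct, the numpy sketch below reproduces ComplexLinear.forward with real-valued weights. It then compares the result with a single complex matrix product. This is an illustration, not the original PyTorch module.

The sketch makes one detail visible. Because each nn.Linear has its own bias, the equivalent complex bias is (b_real - b_imag) + i(b_real + b_imag), not an arbitrary complex number.

```python
import numpy as np

def complex_linear(x, w_real, b_real, w_imag, b_imag):
    """Mirror ComplexLinear.forward with real weights (each layer computes x @ W.T + b)."""
    xr, xi = x.real, x.imag
    out_real = (xr @ w_real.T + b_real) - (xi @ w_imag.T + b_imag)
    out_imag = (xi @ w_real.T + b_real) + (xr @ w_imag.T + b_imag)
    return out_real + 1j * out_imag

rng = np.random.default_rng(0)
in_dim, out_dim = 4, 3
w_real, w_imag = rng.normal(size=(out_dim, in_dim)), rng.normal(size=(out_dim, in_dim))
b_real, b_imag = rng.normal(size=out_dim), rng.normal(size=out_dim)
x = rng.normal(size=(2, in_dim)) + 1j * rng.normal(size=(2, in_dim))

y = complex_linear(x, w_real, b_real, w_imag, b_imag)

# Equivalent single complex layer: weight Wr + iWi, bias (br - bi) + i(br + bi)
w_complex = w_real + 1j * w_imag
b_complex = (b_real - b_imag) + 1j * (b_real + b_imag)
y_ref = x @ w_complex.T + b_complex
print(np.allclose(y, y_ref))  # True
```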
4. Print trainable parameters
```python
def print_trainable_parameters(model):
    """Print the number and percentage of parameters with requires_grad=True."""
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}"
    )
```
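- print_trainable_parameters: prints how many of the model's parameters are trainable and what fraction of the total they represent. This helps assess the model's size and the cost of training it.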
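As a rough check of what this function should report for the LoRA-wrapped GPT-2 in TEMPO, the sketch below counts the LoRA adapter weights. A rank-r adapter on a layer with in_f inputs and out_f outputs adds an r x in_f matrix A and an out_f x r matrix B.

The sketch rests on two assumptions, neither of which is set in this code:
- peft targets GPT-2's fused attention projection c_attn (768 to 2304) by default.
- gpt_layers = 6.

With bias="lora_only", the biases of the adapted layers may be trainable as well, so the reported count would be somewhat higher.

```python
def lora_adapter_params(in_features, out_features, r):
    """Parameters added by one LoRA adapter: A is (r x in), B is (out x r)."""
    return r * in_features + out_features * r

# Assumed: GPT-2 small, hidden size 768, c_attn projects 768 -> 3 * 768 = 2304
per_layer = lora_adapter_params(768, 3 * 768, r=16)
print(per_layer)  # 49152

# Hypothetical gpt_layers = 6 (the real value comes from configs)
gpt_layers = 6
print(gpt_layers * per_layer)  # 294912
```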
5. Define the multi-term Fourier class
```python
class MultiFourier(torch.nn.Module):
    def __init__(self, N, P):
        super(MultiFourier, self).__init__()
        self.N = N  # number of harmonics for each seasonal component
        self.P = P  # period of each seasonal component
        # Learnable cosine (a) and sine (b) coefficients, one column per component
        self.a = torch.nn.Parameter(torch.randn(max(N), len(N)), requires_grad=True)
        self.b = torch.nn.Parameter(torch.randn(max(N), len(N)), requires_grad=True)

    def forward(self, t):
        # t: float tensor of time indices, shape [batch_size, seq_len]
        output = torch.zeros_like(t)
        t = t.unsqueeze(-1).repeat(1, 1, max(self.N))  # shape: [batch_size, seq_len, max(N)]
        n = torch.arange(max(self.N)).unsqueeze(0).unsqueeze(0).to(t.device)  # shape: [1, 1, max(N)]
        for j in range(len(self.N)):  # loop over seasonal components
            cos_terms = torch.cos(2 * np.pi * (n[..., :self.N[j]] + 1) * t[..., :self.N[j]] / self.P[j])  # shape: [batch_size, seq_len, N[j]]
            sin_terms = torch.sin(2 * np.pi * (n[..., :self.N[j]] + 1) * t[..., :self.N[j]] / self.P[j])  # shape: [batch_size, seq_len, N[j]]
            output += torch.matmul(cos_terms, self.a[:self.N[j], j]) + torch.matmul(sin_terms, self.b[:self.N[j], j])
        return output
```
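- MultiFourier: models seasonality with a learnable Fourier series. For each seasonal component j with period P[j], it sums N[j] harmonics.
- forward method: computes the cosine and sine terms of each harmonic, multiplies them (matmul) with the learnable coefficients a and b, and returns the summed seasonal signal. In the TEMPO class below it is instantiated as self.mul_season = MultiFourier([2], [24*4]), but the forward pass shown does not call it.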
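Written out, the module computes y(t) as a sum over components j and harmonics n = 1..N[j]:

a[n-1, j] * cos(2 pi n t / P[j]) + b[n-1, j] * sin(2 pi n t / P[j])

The numpy sketch below re-implements that formula directly for illustration. It uses random coefficients rather than trained values, with the same arguments as self.mul_season. A series with a single period repeats exactly after P[0] steps.

```python
import numpy as np

def fourier_series(t, a, b, N, P):
    """Sum of harmonics n = 1..N[j] for each period P[j], as in MultiFourier.forward.

    t: array of time indices; a, b: coefficient arrays of shape (max(N), len(N)).
    """
    t = np.asarray(t, dtype=float)
    out = np.zeros_like(t)
    for j in range(len(N)):
        n = np.arange(1, N[j] + 1)                      # harmonic numbers 1..N[j]
        phase = 2 * np.pi * n * t[..., None] / P[j]     # shape: t.shape + (N[j],)
        out += np.cos(phase) @ a[:N[j], j] + np.sin(phase) @ b[:N[j], j]
    return out

rng = np.random.default_rng(0)
N, P = [2], [24 * 4]                  # same arguments as self.mul_season in TEMPO
a = rng.normal(size=(max(N), len(N)))
b = rng.normal(size=(max(N), len(N)))

t = np.arange(200, dtype=float).reshape(2, 100)   # [batch_size, seq_len]
y = fourier_series(t, a, b, N, P)
print(y.shape)                                    # (2, 100)
# A single-period series repeats exactly after P[0] steps
print(np.allclose(y, fourier_series(t + P[0], a, b, N, P)))  # True
```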
6. Define the moving-average class
```python
class moving_avg(nn.Module):
    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # x: [batch_size, seq_len, channels]
        # Pad both ends by repeating the first and last time steps
        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
        x = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last dimension, so move time there and back
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x
```
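- moving_avg: a moving average that smooths the series to bring out its trend.
- forward method: pads the start and end of the series by repeating the first and last values, applies a sliding average (AvgPool1d) and returns the smoothed series. With stride=1 and an odd kernel_size (TEMPO uses 25), the output has the same length as the input.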
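The padding-plus-pooling scheme is easy to check on a short ramp. The numpy sketch below is an illustration that uses a cumulative sum in place of nn.AvgPool1d, for stride 1 only. It shows that the interior of a linear trend passes through unchanged, while the ends are pulled toward the repeated edge values.

```python
import numpy as np

def moving_avg_np(x, kernel_size):
    """Edge-padded moving average with stride 1 for x of shape [batch, seq_len, channels]."""
    pad = (kernel_size - 1) // 2
    front = np.repeat(x[:, :1, :], pad, axis=1)   # repeat the first time step
    end = np.repeat(x[:, -1:, :], pad, axis=1)    # repeat the last time step
    xp = np.concatenate([front, x, end], axis=1)
    # Window sums via a cumulative sum: sum(xp[i:i+k]) = csum[i+k] - csum[i]
    csum = np.cumsum(np.pad(xp, ((0, 0), (1, 0), (0, 0))), axis=1)
    return (csum[:, kernel_size:, :] - csum[:, :-kernel_size, :]) / kernel_size

x = np.arange(10, dtype=float).reshape(1, 10, 1)   # a linear ramp 0..9
trend = moving_avg_np(x, kernel_size=5)
print(np.round(trend[0, :, 0], 2).tolist())
# [0.6, 1.2, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 7.8, 8.4]
```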
7. Define the TEMPO model class
```python
class TEMPO(nn.Module):
    def __init__(self, configs, device):
        super(TEMPO, self).__init__()
        self.prompt = getattr(configs, 'prompt', 0)
        self.is_gpt = configs.is_gpt
        self.patch_size = configs.patch_size
        self.pretrain = configs.pretrain
        self.stride = configs.stride
        self.patch_num = (configs.seq_len - self.patch_size) // self.stride + 1
        self.mul_season = MultiFourier([2], [24*4])
        # Pad the end by `stride` steps, which yields one extra patch
        self.padding_patch_layer = nn.ReplicationPad1d((0, self.stride))
        self.patch_num += 1

        # Learned maps applied to the local trend/season estimates in forward
        self.map_trend = nn.Linear(configs.seq_len, configs.seq_len)
        self.map_season = nn.Sequential(
            nn.Linear(configs.seq_len, 4*configs.seq_len),
            nn.ReLU(),
            nn.Linear(4*configs.seq_len, configs.seq_len)
        )
        self.map_resid = nn.Linear(configs.seq_len, configs.seq_len)

        kernel_size = 25
        self.moving_avg = moving_avg(kernel_size, stride=1)

        if configs.pretrain:
            local_model_path = './gpt2/'
            self.gpt2_trend = GPT2Model.from_pretrained(local_model_path, output_attentions=True,
                                                        output_hidden_states=True)
        else:
            print("------------------no pretrain------------------")
            self.gpt2_trend = GPT2Model(GPT2Config())
            # Not used by the forward pass shown below, which only calls gpt2_trend
            self.gpt2_season = GPT2Model(GPT2Config())
            self.gpt2_noise = GPT2Model(GPT2Config())
        # Keep only the first gpt_layers transformer blocks
        self.gpt2_trend.h = self.gpt2_trend.h[:configs.gpt_layers]
        self.prompt = configs.prompt

        # Tokenized text prompts, one per component
        self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        self.gpt2_trend_token = self.tokenizer(text="Predict the future time step given the trend", return_tensors="pt").to(device)
        self.gpt2_season_token = self.tokenizer(text="Predict the future time step given the season", return_tensors="pt").to(device)
        self.gpt2_residual_token = self.tokenizer(text="Predict the future time step given the residual", return_tensors="pt").to(device)
        self.token_len = len(self.gpt2_trend_token['input_ids'][0])

        try:
            self.pool = configs.pool
            if self.pool:
                self.prompt_record_plot = {}
                self.prompt_record_id = 0
                self.diversify = True
        except AttributeError:  # configs has no `pool` field
            self.pool = False

        if self.pool:
            # Learnable prompt pool of pool_size (key, value) pairs
            self.prompt_key_dict = nn.ParameterDict({})
            self.prompt_value_dict = nn.ParameterDict({})
            self.summary_map = nn.Linear(self.patch_num, 1)
            self.pool_size = 30
            self.top_k = 3
            self.prompt_len = 3
            # top_k selected prompts of prompt_len tokens each
            self.token_len = self.prompt_len * self.top_k
            for i in range(self.pool_size):
                prompt_shape = (self.prompt_len, 768)
                key_shape = (768,)
                self.prompt_value_dict[f"prompt_value_{i}"] = nn.Parameter(torch.randn(prompt_shape))
                self.prompt_key_dict[f"prompt_key_{i}"] = nn.Parameter(torch.randn(key_shape))
            self.prompt_record = {f"id_{i}": 0 for i in range(self.pool_size)}
            self.prompt_record_trend = {}
            self.prompt_record_season = {}
            self.prompt_record_residual = {}
            self.diversify = True

        # Patch -> d_model input projections, one per component
        self.in_layer_trend = nn.Linear(configs.patch_size, configs.d_model)
        self.in_layer_season = nn.Linear(configs.patch_size, configs.d_model)
        self.in_layer_noise = nn.Linear(configs.patch_size, configs.d_model)

        if configs.prompt == 1:
            self.use_token = configs.use_token
            if self.use_token == 1:
                # Output heads also read the prompt positions
                self.out_layer_trend = nn.Linear(configs.d_model * (self.patch_num+self.token_len), configs.pred_len)
                self.out_layer_season = nn.Linear(configs.d_model * (self.patch_num+self.token_len), configs.pred_len)
                self.out_layer_noise = nn.Linear(configs.d_model * (self.patch_num+self.token_len), configs.pred_len)
            else:
                self.out_layer_trend = nn.Linear(configs.d_model * self.patch_num, configs.pred_len)
                self.out_layer_season = nn.Linear(configs.d_model * self.patch_num, configs.pred_len)
                self.out_layer_noise = nn.Linear(configs.d_model * self.patch_num, configs.pred_len)
            self.prompt_layer_trend = nn.Linear(configs.d_model, configs.d_model)
            self.prompt_layer_season = nn.Linear(configs.d_model, configs.d_model)
            self.prompt_layer_noise = nn.Linear(configs.d_model, configs.d_model)
            for layer in (self.prompt_layer_trend, self.prompt_layer_season, self.prompt_layer_noise):
                layer.to(device=device)
                layer.train()
        else:
            self.out_layer_trend = nn.Linear(configs.d_model * self.patch_num, configs.pred_len)
            self.out_layer_season = nn.Linear(configs.d_model * self.patch_num, configs.pred_len)
            self.out_layer_noise = nn.Linear(configs.d_model * self.patch_num, configs.pred_len)

        if configs.freeze and configs.pretrain:
            # Freeze GPT-2 except LayerNorm ('ln') and positional-embedding ('wpe') parameters
            for i, (name, param) in enumerate(self.gpt2_trend.named_parameters()):
                if 'ln' in name or 'wpe' in name:
                    param.requires_grad = True
                else:
                    param.requires_grad = False

        # LoRA adapters on GPT-2. get_peft_model may freeze non-adapter parameters,
        # overriding the ln/wpe settings above; print_trainable_parameters shows the result
        config = LoraConfig(
            r=16,
            lora_alpha=16,
            lora_dropout=0.1,
            bias="lora_only",
        )
        self.gpt2_trend = get_peft_model(self.gpt2_trend, config)
        print_trainable_parameters(self.gpt2_trend)

        for layer in (self.gpt2_trend, self.in_layer_trend, self.out_layer_trend,
                      self.in_layer_season, self.out_layer_season, self.in_layer_noise, self.out_layer_noise):
            layer.to(device=device)
            layer.train()
        for layer in (self.map_trend, self.map_season, self.map_resid):
            layer.to(device=device)
            layer.train()

        self.cnt = 0
        self.num_nodes = configs.num_nodes
        # Reversible instance normalization; only rev_in_trend is used in forward
        self.rev_in_trend = RevIn(num_features=self.num_nodes).to(device)
        self.rev_in_season = RevIn(num_features=self.num_nodes).to(device)
        self.rev_in_noise = RevIn(num_features=self.num_nodes).to(device)
```
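- TEMPO: the main model class, combining time-series preprocessing, GPT-2 and auxiliary modules.
- __init__ method:
  - configs: receives the configuration object and sets the model's hyperparameters.
  - patch_num: the number of patches, computed from the input length, patch size and stride, plus one extra patch from the end padding.
  - MultiFourier: initializes the MultiFourier module (self.mul_season) for seasonal components.
  - Moving average: initializes moving_avg (kernel size 25) for smoothing.
  - GPT-2 model: loads a pretrained GPT-2 from a local path or builds a randomly initialized one. It then keeps only the first gpt_layers blocks and wraps the model with LoRA adapters.
  - Tokenization: uses the GPT-2 tokenizer to build the text-prompt tokens.
  - Pool and prompt: initializes the prompt-related parameters. When the pool mechanism is enabled, it also creates the parameter dictionaries of prompt keys and values to select from.
  - Input and output layers: defines separate input and output layers for the trend, season and noise components.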
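The sizes set in __init__ follow from a little arithmetic. Assume seq_len = 512 (taken from the shape comments in forward), together with the hypothetical values patch_size = 16 and stride = 8. Then patch_num comes out as 64, which matches the "4, 64, 768" comment. With the prompt pool, token_len = prompt_len * top_k = 9. The helper names below are hypothetical.

```python
def tempo_patch_num(seq_len, patch_size, stride):
    """patch_num as computed in TEMPO.__init__: unfold count plus one padded patch."""
    return (seq_len - patch_size) // stride + 1 + 1

def head_in_features(d_model, patch_num, token_len, use_token):
    """Input size of out_layer_* when prompt == 1."""
    return d_model * (patch_num + token_len) if use_token == 1 else d_model * patch_num

# Assumed config: seq_len=512, patch_size=16, stride=8
patch_num = tempo_patch_num(512, 16, 8)
print(patch_num)                                          # 64
# With the prompt pool: token_len = prompt_len * top_k = 3 * 3 = 9
print(head_in_features(768, patch_num, 9, use_token=1))   # 56064
print(head_in_features(768, patch_num, 9, use_token=0))   # 49152
```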
8. Function for storing tensors
```python
def store_tensors_in_dict(self, original_x, original_trend, original_season, original_noise, trend_prompts, season_prompts, noise_prompts):
    # TEMPO method: keep one record per sample for later plotting
    for i in range(original_x.size(0)):
        self.prompt_record_plot[self.prompt_record_id + i] = {
            'original_x': original_x[i].tolist(),
            'original_trend': original_trend[i].tolist(),
            'original_season': original_season[i].tolist(),
            'original_noise': original_noise[i].tolist(),
            'trend_prompt': trend_prompts[i],
            'season_prompt': season_prompts[i],
            'noise_prompt': noise_prompts[i],
        }
    # Advance by the batch size (incrementing by 1 let the next batch overwrite these keys)
    self.prompt_record_id += original_x.size(0)
```
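- store_tensors_in_dict: stores, for each sample, the original input, its trend, season and noise components, and the selected prompts in a dictionary for later analysis and visualization.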
9. L2 normalization
```python
def l2_normalize(self, x, dim=None, epsilon=1e-12):
    # TEMPO method: scale x to unit L2 norm along `dim`; epsilon avoids division by zero
    square_sum = torch.sum(x ** 2, dim=dim, keepdim=True)
    x_inv_norm = torch.rsqrt(torch.maximum(square_sum, torch.tensor(epsilon, device=x.device)))
    return x * x_inv_norm
```
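- l2_normalize: L2 normalization, which scales vectors to unit length along the given dimension so that the dot products in select_prompt are cosine similarities. The epsilon term prevents division by zero for all-zero vectors.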
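A numpy version of the same normalization shows both the unit-length result and the effect of epsilon on an all-zero vector. It divides by the clamped norm instead of multiplying by rsqrt, which is equivalent.

```python
import numpy as np

def l2_normalize(x, axis=None, epsilon=1e-12):
    """Scale x to unit L2 norm along `axis`, as in TEMPO.l2_normalize."""
    square_sum = np.sum(x ** 2, axis=axis, keepdims=True)
    return x / np.sqrt(np.maximum(square_sum, epsilon))

v = np.array([[3.0, 4.0], [0.0, 0.0]])
out = l2_normalize(v, axis=1)
print(out.tolist())  # [[0.6, 0.8], [0.0, 0.0]]
```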
10. Prompt selection
```python
def select_prompt(self, summary, prompt_mask=None):
    # TEMPO method. summary: [batch_size * num_nodes, patch_num, 768]
    prompt_key_matrix = torch.stack(tuple([self.prompt_key_dict[i] for i in self.prompt_key_dict.keys()]))
    prompt_norm = self.l2_normalize(prompt_key_matrix, dim=1)  # [pool_size, 768]

    # Summarize each embedding channel across patches: [N, patch_num, 768] -> [N, 768].
    # permute first so each row holds one channel over all patches
    # (a plain view(-1, self.patch_num) would mix adjacent embedding dimensions)
    summary_reshaped = summary.permute(0, 2, 1).reshape(-1, self.patch_num)
    summary_mapped = self.summary_map(summary_reshaped)
    summary = summary_mapped.view(-1, 768)
    summary_embed_norm = self.l2_normalize(summary, dim=1)

    # Cosine similarity between each summary and every prompt key: [N, pool_size]
    similarity = torch.matmul(summary_embed_norm, prompt_norm.t())
    if prompt_mask is not None:
        idx = prompt_mask
    else:
        topk_sim, idx = torch.topk(similarity, k=self.top_k, dim=1)
        # Count how often each prompt is selected
        count_of_keys = torch.bincount(torch.flatten(idx), minlength=self.pool_size)
        for i in range(len(count_of_keys)):
            self.prompt_record[f"id_{i}"] += count_of_keys[i].item()

    prompt_value_matrix = torch.stack(tuple([self.prompt_value_dict[i] for i in self.prompt_value_dict.keys()]))
    batched_prompt_raw = prompt_value_matrix[idx].squeeze(1)  # [N, top_k, prompt_len, 768]
    batch_size, top_k, length, c = batched_prompt_raw.shape
    batched_prompt = batched_prompt_raw.reshape(batch_size, top_k * length, c)

    # Similarity between the summaries and their selected keys
    batched_key_norm = prompt_norm[idx]
    summary_embed_norm = summary_embed_norm.unsqueeze(1)
    sim = batched_key_norm * summary_embed_norm
    reduce_sim = torch.sum(sim) / summary.shape[0]

    selected_prompts = [tuple(sorted(row)) for row in idx.tolist()]
    return batched_prompt, reduce_sim, selected_prompts
```
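- select_prompt: selects the most relevant prompts from the pool by comparing a summary of the input with every prompt key using cosine similarity. It returns the selected prompt embeddings, the summed similarity reduce_sim, and the indices of the chosen prompts.
- If prompt_mask is given, the prompts it specifies are used directly. Otherwise the top_k most similar prompts are selected and their usage counts are recorded.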
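The selection logic can be illustrated with a simplified numpy sketch on a toy 2-D pool. The sketch differs from the real code in three ways:
- It assumes the input summary has already been reduced to one vector per sample, which is the role of summary_map in the real code.
- It uses argsort instead of torch.topk.
- It computes reduce_sim as the sum of the selected cosine similarities divided by the batch size, which is what the original element-wise product and sum amount to.

```python
import numpy as np

def select_prompts(summary, keys, values, top_k):
    """Pick the top_k prompts whose keys are most cosine-similar to each summary.

    summary: [batch, dim], keys: [pool_size, dim], values: [pool_size, prompt_len, dim]
    """
    s = summary / np.linalg.norm(summary, axis=1, keepdims=True)
    k = keys / np.linalg.norm(keys, axis=1, keepdims=True)
    similarity = s @ k.T                                   # [batch, pool_size]
    idx = np.argsort(-similarity, axis=1)[:, :top_k]       # indices of the top_k keys
    chosen = values[idx]                                   # [batch, top_k, prompt_len, dim]
    batch, _, prompt_len, dim = chosen.shape
    prompts = chosen.reshape(batch, top_k * prompt_len, dim)
    reduce_sim = np.take_along_axis(similarity, idx, axis=1).sum() / batch
    return prompts, reduce_sim, idx

# Toy pool of 4 keys in 2-D; prompt_len = 3
keys = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.7, 0.7]])
values = np.arange(4 * 3 * 2, dtype=float).reshape(4, 3, 2)
summary = np.array([[2.0, 0.1]])                           # points roughly along +x

prompts, reduce_sim, idx = select_prompts(summary, keys, values, top_k=2)
print(idx.tolist())        # [[0, 3]]
print(prompts.shape)       # (1, 6, 2)
```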
11. Normalization and patch extraction
```python
def get_norm(self, x, d='norm'):
    # TEMPO method: per-series standardization over the time dimension (dim 1)
    means = x.mean(1, keepdim=True).detach()
    x = x - means
    stdev = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
    x /= stdev
    return x, means, stdev

def get_patch(self, x):
    # TEMPO method: [B, L, M] -> [B*M, patch_num, patch_size]
    x = rearrange(x, 'b l m -> b m l')
    x = self.padding_patch_layer(x)  # repeat the last value `stride` times
    x = x.unfold(dimension=-1, size=self.patch_size, step=self.stride)
    x = rearrange(x, 'b m n p -> (b m) n p')
    return x
```
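- get_norm: standardizes each series along the time dimension by subtracting the mean and dividing by the standard deviation, so that it has zero mean and unit variance. The detached statistics are returned as well; the d argument is unused.
- get_patch: rearranges the input to [batch, channels, length] and pads the end by repeating the last value stride times. It then cuts the series into overlapping patches of length patch_size with step stride (unfold) for the following layers.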
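The patching step can be mirrored in numpy with sliding_window_view, which plays the role of Tensor.unfold here. The sketch assumes the hypothetical values patch_size = 16 and stride = 8, with seq_len = 512 and a single channel. It illustrates the shapes, not the model code.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def get_patch_np(x, patch_size, stride):
    """Mimic TEMPO.get_patch: [B, L, M] -> [B*M, patch_num, patch_size]."""
    x = np.transpose(x, (0, 2, 1))                                  # b l m -> b m l
    # ReplicationPad1d((0, stride)): repeat the last value `stride` times
    x = np.concatenate([x, np.repeat(x[..., -1:], stride, axis=-1)], axis=-1)
    # unfold(size=patch_size, step=stride): all windows, then every stride-th one
    windows = sliding_window_view(x, patch_size, axis=-1)[..., ::stride, :]
    b, m, n, p = windows.shape
    return windows.reshape(b * m, n, p)                             # (b m) n p

x = np.random.default_rng(0).normal(size=(4, 512, 1))
patches = get_patch_np(x, patch_size=16, stride=8)
print(patches.shape)   # (4, 64, 16)
```

The last patch starts at position 504, so its final 8 values are all copies of the last input value.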
12. Getting the embeddings
```python
def get_emb(self, x, tokens=None, type='Trend'):
    # TEMPO method. x: [B*M, patch_num, d_model]
    if tokens is None:
        # No prompt: every component goes through the shared gpt2_trend
        if type == 'Trend':
            x = self.gpt2_trend(inputs_embeds=x).last_hidden_state
        elif type == 'Season':
            x = self.gpt2_trend(inputs_embeds=x).last_hidden_state
        elif type == 'Residual':
            x = self.gpt2_trend(inputs_embeds=x).last_hidden_state
        return x
    else:
        [a, b, c] = x.shape
        if type == 'Trend':
            if self.pool:
                # Pick prompts from the pool and count how often each combination is used
                prompt_x, reduce_sim, selected_prompts_trend = self.select_prompt(x, prompt_mask=None)
                for selected_prompt_trend in selected_prompts_trend:
                    self.prompt_record_trend[selected_prompt_trend] = self.prompt_record_trend.get(selected_prompt_trend, 0) + 1
                selected_prompts = selected_prompts_trend
            else:
                # Embed the text prompt, repeat it for each sample and project it
                prompt_x = self.gpt2_trend.wte(tokens)
                prompt_x = prompt_x.repeat(a, 1, 1)
                prompt_x = self.prompt_layer_trend(prompt_x)
            # Prepend the prompt to the patch embeddings
            x = torch.cat((prompt_x, x), dim=1)
        elif type == 'Season':
            if self.pool:
                prompt_x, reduce_sim, selected_prompts_season = self.select_prompt(x, prompt_mask=None)
                for selected_prompt_season in selected_prompts_season:
                    self.prompt_record_season[selected_prompt_season] = self.prompt_record_season.get(selected_prompt_season, 0) + 1
                selected_prompts = selected_prompts_season
            else:
                prompt_x = self.gpt2_trend.wte(tokens)
                prompt_x = prompt_x.repeat(a, 1, 1)
                prompt_x = self.prompt_layer_season(prompt_x)
            x = torch.cat((prompt_x, x), dim=1)
        elif type == 'Residual':
            if self.pool:
                prompt_x, reduce_sim, selected_prompts_resid = self.select_prompt(x, prompt_mask=None)
                for selected_prompt_resid in selected_prompts_resid:
                    self.prompt_record_residual[selected_prompt_resid] = self.prompt_record_residual.get(selected_prompt_resid, 0) + 1
                selected_prompts = selected_prompts_resid
            else:
                prompt_x = self.gpt2_trend.wte(tokens)
                prompt_x = prompt_x.repeat(a, 1, 1)
                prompt_x = self.prompt_layer_noise(prompt_x)
            x = torch.cat((prompt_x, x), dim=1)
        if self.pool:
            return x, reduce_sim, selected_prompts
        else:
            return x
```
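- get_emb: builds the embedding sequence for one component (trend, season or residual). Without prompt tokens it simply runs GPT-2 on the patch embeddings; with prompts it prepends prompt embeddings to them.
- For each component, the prompt comes from one of two sources:
  - the learnable pool, via select_prompt, with selection counts recorded;
  - the embedded text prompt, passed through that component's prompt_layer.
- Either way, the prompt embeddings are concatenated in front of the input along the sequence dimension.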
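The non-pool prompt path only changes the sequence length. In the numpy sketch below, a constant array stands in for gpt2_trend.wte(tokens). token_len = 9 is a hypothetical value; the real one depends on how the GPT-2 tokenizer splits the prompt text.

```python
import numpy as np

# Shapes as in the forward comments: B*M = 4 series, patch_num = 64, d_model = 768
a, patch_num, d_model = 4, 64, 768
x = np.zeros((a, patch_num, d_model))

# Hypothetical stand-in for gpt2_trend.wte(tokens): one embedded prompt of token_len tokens
token_len = 9
prompt_x = np.ones((1, token_len, d_model))

prompt_x = np.tile(prompt_x, (a, 1, 1))           # like prompt_x.repeat(a, 1, 1)
x_with_prompt = np.concatenate((prompt_x, x), axis=1)
print(x_with_prompt.shape)                        # (4, 73, 768)

# Slicing off the first token_len positions (use_token == 0) recovers the patch part
print(x_with_prompt[:, token_len:, :].shape)      # (4, 64, 768)
```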
13. Model forward pass
```python
def forward(self, x, itr, trend, season, noise, test=False):
    # TEMPO method. x: [B, L, M], e.g. 4, 512, 1
    B, L, M = x.shape
    x = self.rev_in_trend(x, 'norm')
    original_x = x

    # Local decomposition; the squeeze()/unsqueeze(2) calls assume M == 1 and B > 1
    trend_local = self.moving_avg(x)
    trend_local = self.map_trend(trend_local.squeeze()).unsqueeze(2)
    season_local = x - trend_local
    season_local = self.map_season(season_local.squeeze().unsqueeze(1)).squeeze(1).unsqueeze(2)
    noise_local = x - trend_local - season_local

    # Local loss against the reference components; normalize them only if they are given
    loss_local = None
    if trend is not None:
        trend, means_trend, stdev_trend = self.get_norm(trend)
        season, means_season, stdev_season = self.get_norm(season)
        noise, means_noise, stdev_noise = self.get_norm(noise)
        trend_local_l = criterion(trend, trend_local)
        season_local_l = criterion(season, season_local)
        noise_local_l = criterion(noise, noise_local)
        loss_local = trend_local_l + season_local_l + noise_local_l
        if test:
            print("trend local loss:", torch.mean(trend_local_l))
            print("Season local loss", torch.mean(season_local_l))
            print("noise local loss", torch.mean(noise_local_l))

    # Patch the local components: [B*M, patch_num, patch_size]
    trend = self.get_patch(trend_local)
    season = self.get_patch(season_local)
    noise = self.get_patch(noise_local)

    # Embed each component, with prompts if enabled. Without prompts, get_emb already
    # applies GPT-2, so the components pass through it a second time below
    trend = self.in_layer_trend(trend)  # 4, 64, 768
    if self.is_gpt and self.prompt == 1:
        if self.pool:
            trend, reduce_sim_trend, trend_selected_prompts = self.get_emb(trend, self.gpt2_trend_token['input_ids'], 'Trend')
        else:
            trend = self.get_emb(trend, self.gpt2_trend_token['input_ids'], 'Trend')
    else:
        trend = self.get_emb(trend)

    season = self.in_layer_season(season)  # 4, 64, 768
    if self.is_gpt and self.prompt == 1:
        if self.pool:
            season, reduce_sim_season, season_selected_prompts = self.get_emb(season, self.gpt2_season_token['input_ids'], 'Season')
        else:
            season = self.get_emb(season, self.gpt2_season_token['input_ids'], 'Season')
    else:
        season = self.get_emb(season)

    noise = self.in_layer_noise(noise)
    if self.is_gpt and self.prompt == 1:
        if self.pool:
            noise, reduce_sim_noise, noise_selected_prompts = self.get_emb(noise, self.gpt2_residual_token['input_ids'], 'Residual')
        else:
            noise = self.get_emb(noise, self.gpt2_residual_token['input_ids'], 'Residual')
    else:
        noise = self.get_emb(noise)

    # One pass of the shared GPT-2 over the concatenated sequence
    x_all = torch.cat((trend, season, noise), dim=1)
    x = self.gpt2_trend(inputs_embeds=x_all).last_hidden_state

    # Split the hidden states back into the three components
    if self.prompt == 1:
        trend = x[:, :self.token_len+self.patch_num, :]
        season = x[:, self.token_len+self.patch_num:2*self.token_len+2*self.patch_num, :]
        noise = x[:, 2*self.token_len+2*self.patch_num:, :]
        if self.use_token == 0:
            # Drop the prompt positions
            trend = trend[:, self.token_len:, :]
            season = season[:, self.token_len:, :]
            noise = noise[:, self.token_len:, :]
    else:
        trend = x[:, :self.patch_num, :]
        season = x[:, self.patch_num:2*self.patch_num, :]
        noise = x[:, 2*self.patch_num:, :]

    # Map each component to the forecast horizon, sum, and denormalize
    trend = self.out_layer_trend(trend.reshape(B*M, -1))  # 4, 96
    trend = rearrange(trend, '(b m) l -> b l m', b=B)  # 4, 96, 1
    season = self.out_layer_season(season.reshape(B*M, -1))  # 4, 96
    season = rearrange(season, '(b m) l -> b l m', b=B)  # 4, 96, 1
    noise = self.out_layer_noise(noise.reshape(B*M, -1))  # 4, 96
    noise = rearrange(noise, '(b m) l -> b l m', b=B)
    outputs = trend + season + noise
    outputs = self.rev_in_trend(outputs, 'denorm')
    return outputs, loss_local
```
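- forward: implements the forward pass, processing the input and computing the output.
- Data preprocessing: first normalizes the input with reversible instance normalization (rev_in_trend, mode 'norm'), then computes the local trend, season and noise components.
- Loss computation: when reference components are given, normalizes them and computes their MSE against the local components. In test mode the losses are also printed.
- Patching: converts the local trend, season and noise components into patches.
- Embedding: projects each component with its own input layer and builds its embedding sequence, with or without prompts depending on the configuration.
- Combining the outputs: the concatenated sequences pass through the shared GPT-2, and each segment is mapped to the forecast horizon by its own output layer. The three forecasts are then summed and denormalized (mode 'denorm') to obtain the final prediction.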
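RevIn comes from the user module utils.rev_in, which is not shown. The numpy sketch below illustrates the 'norm'/'denorm' round trip under one assumption: that RevIn behaves like reversible instance normalization without learnable affine parameters. The real class may differ, and RevInSketch is a hypothetical name.

The saved statistics have shape [B, 1, M]. They therefore broadcast over an output of any length, which is how a pred_len forecast can be denormalized with statistics from the seq_len input.

```python
import numpy as np

class RevInSketch:
    """Hypothetical stand-in for utils.rev_in.RevIn (no learnable affine parameters)."""

    def __init__(self, eps=1e-5):
        self.eps = eps

    def __call__(self, x, mode):
        if mode == 'norm':
            # Statistics per series and channel, over the time dimension (axis 1)
            self.mean = x.mean(axis=1, keepdims=True)
            self.std = np.sqrt(x.var(axis=1, keepdims=True) + self.eps)
            return (x - self.mean) / self.std
        if mode == 'denorm':
            # Reuse the statistics saved by the last 'norm' call
            return x * self.std + self.mean
        raise ValueError(f"unknown mode: {mode}")

rev_in = RevInSketch()
x = np.random.default_rng(0).normal(loc=50.0, scale=10.0, size=(4, 512, 1))
x_norm = rev_in(x, 'norm')
print(np.allclose(x_norm.mean(axis=1), 0.0))        # True
print(np.allclose(rev_in(x_norm, 'denorm'), x))     # True

forecast = np.zeros((4, 96, 1))   # e.g. a pred_len = 96 forecast in normalized units
print(rev_in(forecast, 'denorm').shape)             # (4, 96, 1)
```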
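Summary
This code implements a time-series forecasting model that decomposes a series into trend, season and noise. It models the three components with a shared, LoRA-adapted GPT-2 backbone, optionally guided by text prompts or a learnable prompt pool. The code also defines a multi-harmonic Fourier module (MultiFourier), although the forward pass shown does not use it. The custom modules and methods make the model flexible and extensible.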