informer 辅助笔记:main_informer.py

运行 informer的主文件

import argparse
import os
import torch

from exp.exp_informer import Exp_Informer

1 参数

parser.add_argument的这些

参数名称参数描述
model实验模型。可以设置为informer、informerstack、informerlight(TBD)
data数据集名称
root_path数据文件的根路径(默认为./data/ETT/)
data_path数据文件名称(默认为ETTh1.csv)
features

预测任务(默认为M)。可以设置为M、S、MS

(M:多变量预测多变量,S:单变量预测单变量,MS:多变量预测单变量)

targetS或MS任务中的目标特征(默认为OT)
freq

时间特征编码的频率(默认为h)

可以设置为s(秒)、t(分钟)、h(小时)、d(日)、b(工作日)、w(周)、m(月)。也可以使用更详细的频率,如15min或3h

checkpoints模型检查点的位置(默认为./checkpoints/)
seq_lenInformer编码器的输入序列长度(默认为96)
label_lenInformer解码器的起始标记长度(默认为48)
pred_len预测序列长度(默认为24)
enc_in编码器输入大小(默认为7)
dec_in解码器输入大小(默认为7)
c_out输出大小(默认为7)
d_model模型的维度(默认为512)
n_heads头的数量(默认为8)
e_layers编码器层的数量(默认为2)
d_layers解码器层的数量(默认为1)
s_layers堆叠编码器层的数量(默认为3,2,1)
d_fffcn的维度(默认为2048)
factorProbsparse attn因子(默认为5)
padding填充类型(默认为0)
distil是否在编码器中使用提炼,使用此参数表示不使用提炼(默认为True)
dropout丢弃的概率(默认为0.05)
attn编码器中使用的注意力(默认为prob)。可以设置为prob(informer)、full(transformer)
embed时间特征的编码(默认为timeF)。可以设置为timeF、fixed、learned
activation激活函数(默认为gelu)
output_attention是否在编码器中输出注意力,使用此参数表示输出注意力(默认为False)
do_predict是否预测未见的未来数据,使用此参数表示进行预测(默认为False)
mix是否在生成解码器中使用混合注意力,使用此参数表示不使用混合注意力(默认为True)
cols数据文件中作为输入特征的某些列
num_workersData loader的工作数(默认为0)
itr实验次数(默认为2)
train_epochs训练周期(默认为6)
batch_size训练输入数据的批量大小(默认为32)
patience提前停止的耐心(默认为3)
learning_rate优化器学习率(默认为0.0001)
des实验描述(默认为test)
loss损失函数(默认为mse)
lradj调整学习率的方式(默认为type1)
use_amp是否使用自动混合精度训练,使用此参数表示使用amp(默认为False)
inverse是否反转输出数据,使用此参数表示反转输出数据(默认为False)
use_gpu是否使用gpu(默认为True)
gpu用于训练和推理的gpu编号(默认为0)
use_multi_gpu是否使用多个gpu,使用此参数表示使用多个gpu(默认为False)
devices多个gpu的设备ID(默认为0,1,2,3)

2 其他部分

2.1 GPU 相关

args = parser.parse_args()
#解析命令行参数

args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False
#检查是否可以使用 GPU

if args.use_gpu and args.use_multi_gpu:
    args.devices = args.devices.replace(' ','')
    device_ids = args.devices.split(',')
    args.device_ids = [int(id_) for id_ in device_ids]
    args.gpu = args.device_ids[0]
    #如果启用了 GPU 且设置了多 GPU 使用,代码会解析 GPU 设备 ID,并准备相应的 GPU 设置。


2.2 数据集相关

data_parser = {
    'ETTh1':{'data':'ETTh1.csv','T':'OT','M':[7,7,7],'S':[1,1,1],'MS':[7,7,1]},
    'ETTh2':{'data':'ETTh2.csv','T':'OT','M':[7,7,7],'S':[1,1,1],'MS':[7,7,1]},
    'ETTm1':{'data':'ETTm1.csv','T':'OT','M':[7,7,7],'S':[1,1,1],'MS':[7,7,1]},
    'ETTm2':{'data':'ETTm2.csv','T':'OT','M':[7,7,7],'S':[1,1,1],'MS':[7,7,1]},
    'WTH':{'data':'WTH.csv','T':'WetBulbCelsius','M':[12,12,12],'S':[1,1,1],'MS':[12,12,1]},
    'ECL':{'data':'ECL.csv','T':'MT_320','M':[321,321,321],'S':[1,1,1],'MS':[321,321,1]},
    'Solar':{'data':'solar_AL.csv','T':'POWER_136','M':[137,137,137],'S':[1,1,1],'MS':[137,137,1]},
}
'''
这是一个数据解析器字典,包含不同数据集的配置信息

如文件名,目标列,输入输出目标维度

(M:多变量预测多变量,S:单变量预测单变量,MS:多变量预测单变量)
'''


if args.data in data_parser.keys():
    data_info = data_parser[args.data]
    args.data_path = data_info['data']
    args.target = data_info['T']
    args.enc_in, args.dec_in, args.c_out = data_info[args.features]
'''
检查输入的数据集是否在 data_parser 中定义,如果是,则从字典中获取相应的配置。

数据路径、目标列、encoder输入、decoder输入、decoder输出的维度
'''

2.3 设置参数


args.s_layers = [int(s_l) for s_l in args.s_layers.replace(' ','').split(',')]
#解析并设置堆叠层的数量

args.detail_freq = args.freq
#将频率详细信息保存在另一个参数中

args.freq = args.freq[-1:]
#频率的最后一个元素(一般是h,s,m这些)

print('Args in experiment:')
print(args)

2.4

Exp = Exp_Informer

for ii in range(args.itr):
    # 对于每次迭代,根据实验参数设置进行训练和测试。
    setting = '{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_at{}_fc{}_eb{}_dt{}_mx{}_{}_{}'.format(args.model, args.data, args.features, 
                args.seq_len, args.label_len, args.pred_len,
                args.d_model, args.n_heads, args.e_layers, args.d_layers, args.d_ff, args.attn, args.factor, 
                args.embed, args.distil, args.mix, args.des, ii)

    exp = Exp(args) # 使用给定参数实例化实验对象
    print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
    exp.train(setting)
    
    print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
    exp.test(setting)
    #分别对模型进行训练和测试

    if args.do_predict:
        print('>>>>>>>predicting : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
        exp.predict(setting, True)
    #如果设置为进行预测,那么执行预测

    torch.cuda.empty_cache()

  • 22
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

UQI-LIUWJ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值