如果大家不知道informer改进了哪些以及创新点可以去看我的博文链接: 从0开始学习informer
环境pyhthon == 3.10 pytorch安装最新的就行了
1.获取数据
工欲善其事必先利其器,这里选取的股价数据,是以沪深500的工商银行19-24年1月的数据进行为例子。
我是在聚宽上面获取的数据。
这里不多简述在data里面
#获取数据列的参数和画图横纵坐标参数,以及数据集名称
x_columns, y_columns,xt_columns,yt_columns, plt_x_label, plt_y_label, train_size,data_name = Acquire_data_column()
讲一下Acquire_data_column()函数的作用, x_columns是我们需要训练需要用到的列,因为informer创新点是带入的时间维度,我们的数据是按照day进行的,所以xt_columns对应的就是第0维的列和我们需要训练的数据, yt_columns就是label所需要的参数。我们第一列是开盘价,就以预测开盘价为例子,但是一般都是预测收盘价。你可以自行修改。
下面是具体源码
def Acquire_data_column():
data_name = "gsyh.csv"#随便取,主要是可以区分不同的数据集创建出不同的文件夹,方便我们训练
plt_x_label = 'time/day'#画图参数
plt_y_label = 'Open/rmb'#画图参数
train_size = 0.8
x_columns = [1,2, 3, 4, 5,6]#训练的时候encoder的输入维度对应的列
y_columns = 1#训练的时候decoder的输入维度对应的列
xt_columns= [0, 2, 3, 4, 5, 6]#训练的时候decoder的输出维度对应的列,但是带时间因为informer创新点是带时间的 [0,1]中的0就是时间(date)列
yt_columns = [0, 1]#训练的时候decoder的输出维度对应的列,但是带时间因为informer创新点是带时间的
return x_columns, y_columns,xt_columns,yt_columns, plt_x_label, plt_y_label, train_size,data_name
定义模型的名称,方便创建文件夹好辨认
主函数里面还有Acquire_files,这是获取我们数据所在的路径,主要是。
model_name=f'Informer_model_{data_name}'
#获取文件夹路径名称
save_model_dict_path,png_save_path,png_filename,data_path =Acquire_files(model_name=model_name)
def Acquire_files(model_name):
save_model_dict_path=f'informer_add/dict/{model_name}'#保存最优权重的路径
png_save_path = f'informer_add/picture/{model_name}'
png_filename = f'{model_name}'#图片保存名字
data_path='informer_add/datas/601398.XSHG(工商银行14-24).csv'#你要训练的数据路径
return save_model_dict_path,png_save_path,png_filename,data_path
然后接下来是一些模型参数了关于这些可以在我的文章中得到链接: 从0开始的informer代码解读
make_data_config = DataConfig(data_path=data_path,x_columns=x_columns,y_columns=y_columns,xt_columns=xt_columns,yt_columns=yt_columns,train_size=train_size)
model_config = ModelConfig(enc_in=enc_in,dec_in = dec_in,c_out=c_out,model_name=model_name,lr=lr,epochs=epochs,dropout=dropout,device=device,pre_len=pre_len,batch_size=batch_size,s_len=s_len,best_model_path=save_model_dict_path)
2.创建数据集
至于获取数据集里面的内容细节是label可以是一维的,因为我们预测是一个MS多维预测单维度
所以我们在取数据的时候不能把我们要预测的维度取过来,不然会导致数据泄露,打不到预测的目的了。关于我自定义类可以在我分享的源码里面看。这里不过多赘述。
#获取数据集
train_x, test_x, train_y, test_y,x_stand,y_stand= read_data(make_data_config=make_data_config,model_config=model_config)
train_data = SocketData(train_x, train_y,x_stand,y_stand)
train_data = DataLoader(train_data, shuffle=True, batch_size=batch_size)
test_data = SocketData(test_x, test_y,x_stand,y_stand)
test_data = DataLoader(test_data, shuffle=True, batch_size=batch_size)
3.最后创建模型和训练并且打印出我们需要的图片
#创建模型
model = Informer(model_config).to(device)
#训练和评估模型
losses, predictions, targets=train_and_test(model=model,train_data=train_data,test_data=test_data,y_scaler=y_stand,model_config=model_config)
#保存图片
plot_and_save(losses, predictions[-1], targets[-1],png_save_path=png_save_path,png_filename=png_filename,plt_x_label=plt_x_label,plt_y_label=plt_y_label)
讲解一下参数
主要就是下面三个,对应的就是获取的x_columns和y_columns的对应列数 enc_in对应x_columns ,dec_in 对应y_columns。比如5列enc_in就是5.
enc_in = 5#编码器输入维度
dec_in = 1#解码器输入维度
c_out = 1 #解码器输出维度
效果如下
总代码
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler
from models.Informer_model.model import Informer
from untils.Data import read_data,DataConfig,Acquire_data_column
from untils.ObjectClass import SocketData,ModelConfig
from untils.evaluation_function import train_and_test,plot_and_save
from untils.Creat_file import create_directories,Acquire_files
x_stand = StandardScaler()
y_stand = StandardScaler()
s_len = 64
pre_len = 3
batch_size = 64
device = "cuda"
lr = 0.001
epochs = 50
dropout = 0.05
enc_in = 5#编码器输入维度
dec_in = 1#解码器输入维度
c_out = 1 #解码器输出维度
#获取数据列的参数和画图横纵坐标参数,以及数据集名称
x_columns, y_columns,xt_columns,yt_columns, plt_x_label, plt_y_label, train_size,data_name = Acquire_data_column()
#定义文件夹路径名称
model_name=f'Informer_model_{data_name}'
#获取文件夹路径名称
save_model_dict_path,png_save_path,png_filename,data_path =Acquire_files(model_name=model_name)
# 调用函数创建文件夹
create_directories(save_model_dict_path, png_save_path,clear_dir=True)#clear_dir是否清空文件夹的选择
make_data_config = DataConfig(data_path=data_path,x_columns=x_columns,y_columns=y_columns,xt_columns=xt_columns,yt_columns=yt_columns,train_size=train_size)
model_config = ModelConfig(enc_in=enc_in,dec_in = dec_in,c_out=c_out,model_name=model_name,lr=lr,epochs=epochs,dropout=dropout,device=device,pre_len=pre_len,batch_size=batch_size,s_len=s_len,best_model_path=save_model_dict_path)
#获取数据集
train_x, test_x, train_y, test_y,x_stand,y_stand= read_data(make_data_config=make_data_config,model_config=model_config)
train_data = SocketData(train_x, train_y,x_stand,y_stand)
train_data = DataLoader(train_data, shuffle=True, batch_size=batch_size)
test_data = SocketData(test_x, test_y,x_stand,y_stand)
test_data = DataLoader(test_data, shuffle=True, batch_size=batch_size)
#创建模型
model = Informer(model_config).to(device)
#训练和评估模型
losses, predictions, targets=train_and_test(model=model,train_data=train_data,test_data=test_data,y_scaler=y_stand,model_config=model_config)
#保存图片
plot_and_save(losses, predictions[-1], targets[-1],png_save_path=png_save_path,png_filename=png_filename,plt_x_label=plt_x_label,plt_y_label=plt_y_label)
全部代码会在群里更新,讲的这么简介的原因是我的博文中有提到一些基本步骤不懂的可以看链接: nlp时序模型股价预测的基本思路(持续更新)
交流群在简介,也是欢迎大佬加入,一起交流。