前面的文章已经介绍,将短线个股挖掘问题转化为深度学习处理的分类问题,并且已经完成训练,将训练得到的模型保存到本地。本文将记录如何使用Keras加载模型并进行预测的过程。
结果预测
首先,找到训练模型保存的目录,加载模型:
# 加载模型
loaded_model = keras.models.load_model('./model/{}'.format(stk_code))
然后,读入数据,将数据转化为字典类型作为预测所使用的输入字典,键为特征的索引,值为tensor。我们使用了220个特征,索引值依次为0至219。
# 读入数据
data_file = './baostock/prediction_data_pre/{}.csv'.format(stk_code)
in_df = pd.read_csv(data_file)
# 预测用的输入字典
temp_dict = {}
# 将数据导入输入字典
for i in range(in_df.shape[1]):
temp_dict[i] = in_df['{}'.format(i)].tolist()
input_dict = {name: tf.convert_to_tensor(value) for name, value in temp_dict.items()}
接着,调用模型的predict方法进行预测,将预测结果保存到列表results中。
# 进行预测
predictions = loaded_model.predict(input_dict)
results = []
for i in range(in_df.shape[0]):
results.append(predictions[i][0])
然后,我们在未来用于回测的数据后添加一列predict_result,并保存到本地。这样backtrader就可以通过加载本地文件,完成基于深度学习的回测。
# 输出到文件
data_file = './baostock/data_ext/{}.csv'.format(stk_code)
out_df = pd.read_csv(data_file)
out_df = out_df[(out_df['date'] > '2017-12-31') & (out_df['date'] <= '2020-06-30')]
out_df['predict_result'] = results
out_df.to_csv('./baostock/predict_results/{}res.csv'.format(stk_code), index = False)
最后,还是记得在每只股票完成预测后,清理内存,以防内存被刷爆。
# 清理内存
backend.clear_session()
以上就完成了加载本地模型进行预测的过程,完整代码如下。下一篇文章将记录如果使用预测结果,进行多股回测。
import tensorflow as tf
import numpy as np
import pandas as pd
import os
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend
stk_code_file = './stk_data/dp_stock_list.csv'
stk_list = pd.read_csv(stk_code_file)['code'].tolist()
for stk_code in stk_list:
print('processing {} ...'.format(stk_code))
# 加载模型
loaded_model = keras.models.load_model('./model/{}'.format(stk_code))
# 读入数据
data_file = './baostock/prediction_data_pre/{}.csv'.format(stk_code)
in_df = pd.read_csv(data_file)
# 预测用的输入字典
temp_dict = {}
# 将数据导入输入字典
for i in range(in_df.shape[1]):
temp_dict[i] = in_df['{}'.format(i)].tolist()
input_dict = {name: tf.convert_to_tensor(value) for name, value in temp_dict.items()}
# 进行预测
predictions = loaded_model.predict(input_dict)
results = []
for i in range(in_df.shape[0]):
results.append(predictions[i][0])
# 输出到文件
data_file = './baostock/data_ext/{}.csv'.format(stk_code)
out_df = pd.read_csv(data_file)
out_df = out_df[(out_df['date'] > '2017-12-31') & (out_df['date'] <= '2020-06-30')]
out_df['predict_result'] = results
out_df.to_csv('./baostock/predict_results/{}res.csv'.format(stk_code), index = False)
# 清理内存
backend.clear_session()
博客内容只用于交流学习,不构成投资建议,盈亏自负!
个人博客:http://coderx.com.cn/(优先更新)
项目最新代码:https://gitee.com/sl/quant_from_scratch
欢迎大家转发、留言。有微信群用于学习交流,感兴趣的读者请扫码加微信!
如果认为博客对您有帮助,可以扫码进行捐赠,感谢!
微信二维码 | 微信捐赠二维码 |
---|---|