基于上一个版本的ARIMA 算法进行改进,本次修改只改动了预测部分的算法:从ARIMA 算法改为LSTM 算法。
现在的预测结果:
之前的预测结果:
从结果的图片上可以看出LSTM 更贴近真实的数据曲线,准确性也相比ARIMA 算法要好,但是LSTM 的运行时间较长:
而ARIMA 的运行时间很短。在代码的复杂度上,LSTM 也要高于ARIMA。
def get_stock_hist_realtimedeal(code, level):
    """
    Historical time-series data endpoint.

    For the plain levels the data returned by
    ``StockApi.get_stock_hist_realtimedeal`` is passed through unchanged.
    For ``'Forecast'`` an LSTM is trained on the weekly mean of the ``'o'``
    column and its predictions for the 2023-2024 weeks are returned next to
    the actual values.

    :param code: stock code
    :param level: one of '5', '15', '30', '60', 'Day', 'Week', 'Month',
        'Year' or 'Forecast'
    :return: the object built by ``resp``; an invalid level yields the
        STOCK_LEVEL_INVALID response and any exception yields the
        OUTER_INTERFACE_EXCEPTION response
    """
    try:
        if level in ('5', '15', '30', '60', 'Day', 'Week', 'Month', 'Year'):
            return resp(data=StockApi.get_stock_hist_realtimedeal(code, level))
        if level == 'Forecast':
            return resp(data=_forecast_weekly_open(code))
        return resp(ResponseEnum.STOCK_LEVEL_INVALID.value['code'],
                    ResponseEnum.STOCK_LEVEL_INVALID.value['msg'])
    except Exception as e:
        # Top-level boundary of the endpoint: report the failure and map it
        # to the generic "outer interface" error response.
        print(e)
        return resp(ResponseEnum.OUTER_INTERFACE_EXCEPTION.value['code'],
                    ResponseEnum.OUTER_INTERFACE_EXCEPTION.value['msg'])


def _sliding_windows(values, window):
    """Cut a 1-D sequence into LSTM samples.

    Returns ``(X, y)`` where ``X`` has shape ``(n, window, 1)`` and ``y[i]``
    is the value that immediately follows window ``X[i]``.

    :raises ValueError: if there are not more than ``window`` values, i.e.
        not a single sample can be formed
    """
    values = np.asarray(values).reshape(-1)
    if len(values) <= window:
        raise ValueError(
            'need more than %d points to build a window, got %d'
            % (window, len(values)))
    samples = np.stack(
        [values[end - window:end] for end in range(window, len(values))])
    targets = values[window:]
    return samples.reshape(samples.shape[0], window, 1), targets


def _forecast_weekly_open(code, window=60):
    """Train an LSTM on weekly data and return the forecast as a JSON string.

    The returned string encodes ``{'origin': <json>, 'pred': <json>}`` where
    both values are themselves JSON-encoded lists of ``{'date', 'val'}``
    records (the nested encoding is kept for compatibility with existing
    consumers of this endpoint).

    :param code: stock code passed to ``StockApi``
    :param window: number of preceding weekly points fed to the model for
        each prediction
    :raises ValueError: if there is no data in the forecast range or not
        enough history before it
    """
    import io

    # NOTE(review): the pandas module is referenced both as ``pandas`` and
    # ``pd`` below, as in the original code; the import block is not visible
    # from here, so both names are kept.
    data = pandas.DataFrame(StockApi.get_stock_hist_realtimedeal(code, 'Day'))

    # Round-trip through CSV text so that the first column becomes a parsed
    # datetime index and the remaining columns get their types inferred.
    # An in-memory buffer is used instead of a shared file on disk so that
    # concurrent requests cannot read each other's data.
    buffer = io.StringIO()
    data.to_csv(buffer, index=False)
    buffer.seek(0)
    data = pd.read_csv(buffer, index_col=0, parse_dates=[0])

    # Weekly mean of the 'o' column (presumably the open price - confirm
    # against the StockApi schema).
    stock_week = data['o'].resample('W').mean().dropna()
    # NOTE(review): both slices include the 2023 weeks, so the forecast
    # range overlaps the training range.
    stock_train = stock_week['2020':'2023'].dropna()
    stock_data = stock_week['2023':'2024'].dropna()
    if stock_data.empty:
        raise ValueError('no weekly data available in the 2023-2024 range')

    scaler = MinMaxScaler(feature_range=(0, 1))
    train_scaled = scaler.fit_transform(stock_train.values.reshape(-1, 1))
    X_train, y_train = _sliding_windows(train_scaled[:, 0], window)

    # Build and train the LSTM model.
    model = Sequential()
    model.add(LSTM(units=50, return_sequences=True, input_shape=(window, 1)))
    model.add(LSTM(units=50))
    model.add(Dense(1))
    model.compile(optimizer='adam', loss='mean_squared_error')
    model.fit(X_train, y_train, epochs=100, batch_size=32)

    # Forecasting: every week in stock_data is predicted from the `window`
    # weeks that precede it. The slice is anchored at the position of the
    # first forecast week rather than at the tail of the series, so the
    # predictions stay aligned with stock_data's dates even when the feed
    # contains weeks after the forecast range.
    first_pos = stock_week.index.get_loc(stock_data.index[0])
    start = first_pos - window
    if start < 0:
        raise ValueError(
            'need %d weekly points before %s, got %d'
            % (window, stock_data.index[0].date(), first_pos))
    history = stock_week.iloc[start:first_pos + len(stock_data)]
    history_scaled = scaler.transform(history.values.reshape(-1, 1))
    X_test, _ = _sliding_windows(history_scaled[:, 0], window)
    predicted = scaler.inverse_transform(model.predict(X_test))

    # Convert actual and predicted values to JSON records.
    dates = stock_data.index.strftime('%Y-%m-%d')
    predictions = [{'date': day, 'val': format(val, '.2f')}
                   for day, val in zip(dates, predicted.flatten().tolist())]
    origin_data = [{'date': day, 'val': format(val, '.2f')}
                   for day, val in zip(dates, stock_data.values.tolist())]
    obj = {'origin': json.dumps(origin_data), 'pred': json.dumps(predictions)}
    print(obj)
    return json.dumps(obj)