# An interesting example I came across while learning about LSTMs; noting it down here.
import math
import yfinance as yf
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import Dense, LSTM
# Download the daily price history for Kweichow Moutai (600519.SS)
df = yf.download('600519.SS', start='2012-01-01', end='2024-12-10')
# The download may come back empty (bad ticker, no network, ...). Stop here
# with a clear message instead of failing later inside scaling or plotting.
if df.empty:
    raise SystemExit(
        'No price data was returned for 600519.SS; '
        'check the ticker symbol and the network connection.'
    )
# Plot the closing-price history
plt.figure(figsize=(16, 8))
plt.title('Kweichow Moutai Stock Price')
plt.plot(df['Close'])
plt.xlabel('Date', fontsize=18)
plt.ylabel('Closing Price CNY', fontsize=18)
plt.show()
# Use only the closing-price column. The double brackets keep a DataFrame, so
# `.values` is a 2-D array (one column) rather than a flat vector.
data = df[['Close']]
dataset = data.values
# Split point: the first 80% of the rows (rounded up) form the training set.
training_data_len = math.ceil(len(dataset) * 0.8)
# Every sample is a window of the 60 preceding closes, so the training portion
# must contain more than 60 rows or no window can be built.
if training_data_len <= 60:
    raise ValueError(
        f'Need more than 60 training rows to build 60-step windows, '
        f'got {training_data_len}.'
    )
# Preprocessing: scale prices into [0, 1]. The scaler is fitted on the training
# rows only, so the min/max of the held-out period do not leak into training;
# the fitted transform is then applied to the whole series.
Scaler = MinMaxScaler(feature_range=(0, 1))
Scaler.fit(dataset[:training_data_len])
scaled_data = Scaler.transform(dataset)
# Training portion of the scaled series
train_data = scaled_data[0:training_data_len, :]
# Inputs: each run of 60 consecutive scaled closes.
# Targets: the scaled close that immediately follows each run.
x_train = np.array(
    [train_data[i - 60:i, 0] for i in range(60, len(train_data))]
)
y_train = train_data[60:, 0]
# Reshape the inputs to the 3-D (samples, timesteps, features) layout that the
# LSTM layer takes; there is a single feature per timestep.
x_train = np.reshape(x_train, (x_train.shape[0], x_train.shape[1], 1))
# Build the network: two stacked 50-unit LSTM layers (the first one returns
# the full sequence so the second can consume it), followed by a 25-unit
# dense layer and a single-unit output for the predicted scaled close.
model = Sequential([
    LSTM(50, return_sequences=True, input_shape=(x_train.shape[1], 1)),
    LSTM(50, return_sequences=False),
    Dense(25),
    Dense(1),
])
# Compile with the Adam optimizer and a mean-squared-error loss
model.compile(loss='mean_squared_error', optimizer='adam')
# Train for a single epoch, one sample per gradient step
model.fit(x_train, y_train, epochs=1, batch_size=1)
# Evaluate the model
# The test slice starts 60 rows before the split point so that the very first
# test window can be built from the last 60 training rows.
test_data = scaled_data[training_data_len - 60:, :]
# Ground truth is taken from the unscaled series, i.e. in original price units.
y_test = dataset[training_data_len:, :]
# One 60-step window per test row
x_test = np.array(
    [test_data[i - 60:i, 0] for i in range(60, len(test_data))]
)
# Reshape to the 3-D (samples, timesteps, features) layout used for training
x_test = np.reshape(x_test, (x_test.shape[0], x_test.shape[1], 1))
# Predict, then map the scaled outputs back to price units
predictions = model.predict(x_test)
predictions = Scaler.inverse_transform(predictions)
# Root-mean-squared error between predicted and actual closing prices
rmse = np.sqrt(np.mean((predictions - y_test) ** 2))
print(rmse)
# Split the closing prices at the same point used for training, for plotting
train = data[:training_data_len]
# Copy the validation slice before adding a column to it: assigning into a
# slice of `data` triggers pandas' SettingWithCopyWarning, whereas an explicit
# copy is an independent frame that can be modified safely.
valid = data[training_data_len:].copy()
valid['predictions'] = predictions
# Plot training history, actual validation prices and predicted prices
plt.figure(figsize=(16, 8))
plt.title('Kweichow Moutai Stock Price Prediction Model')
plt.plot(train['Close'])
plt.plot(valid[['Close', 'predictions']])
plt.legend(['Train', 'val', 'predictions'], loc='lower right')
plt.xlabel('Date', fontsize=18)
plt.ylabel('Closing Price CNY', fontsize=18)
plt.show()