WeChat Official Account: 尤而小屋
Editor: Peter
Author: Peter
Hi everyone, I'm Peter~
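This post walks through a hands-on deep learning project: stock price prediction with a Long Short-Term Memory (LSTM) model. It covers:
- How to download financial stock data with yfinance
- Visualizing trading volume and closing prices
- Computing 5-day and 10-day moving averages of the stock price
- Calculating the stock's daily returns
- Building an LSTM model to predict the closing price, and more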
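Below is a small pandas sketch of the moving-average and daily-return items above. It uses a synthetic random-walk price series in place of real yfinance data, so it runs offline. The column names `MA5`, `MA10` and `Daily Return` are illustrative choices, not necessarily the ones used later in the post.

```python
import numpy as np
import pandas as pd

# Synthetic closing prices (assumption: a random walk stands in for yfinance data).
rng = np.random.default_rng(0)
dates = pd.bdate_range("2024-01-01", periods=30)  # 30 business days
close = pd.Series(100 + rng.normal(0, 1, size=30).cumsum(), index=dates, name="Adj Close")

prices = close.to_frame()
# Simple moving averages over the last 5 and 10 trading days.
# The first (window - 1) rows are NaN because the window is not full yet.
prices["MA5"] = prices["Adj Close"].rolling(5).mean()
prices["MA10"] = prices["Adj Close"].rolling(10).mean()

# Daily return: relative change from the previous trading day's close.
prices["Daily Return"] = prices["Adj Close"].pct_change()

print(prices.tail())
```

`rolling(n).mean()` averages each row with the n - 1 rows before it. `pct_change()` computes `close_t / close_{t-1} - 1`, so the first row has no return.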
Introductions to LSTM:
1. https://easyai.tech/ai-definition/lstm/
2. https://zh.d2l.ai/chapter_recurrent-modern/lstm.html
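The d2l chapter linked above defines the LSTM cell with three gates. The input (I), forget (F) and output (O) gates are each a sigmoid of `X_t W_x + H_{t-1} W_h + b`, and the candidate memory C̃ uses the same form with tanh. The cell then updates as `C_t = F_t ⊙ C_{t-1} + I_t ⊙ C̃_t` and `H_t = O_t ⊙ tanh(C_t)`.

Below is a minimal NumPy sketch of one forward step in that notation. The weights are random and untrained, and the helper names `lstm_step` / `init_params` and the 60-step sequence length are hypothetical illustrations, not the post's model.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, params):
    """One LSTM time step. x: (batch, n_in); h_prev, c_prev: (batch, n_hidden)."""
    W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c = params
    i = sigmoid(x @ W_xi + h_prev @ W_hi + b_i)        # input gate
    f = sigmoid(x @ W_xf + h_prev @ W_hf + b_f)        # forget gate
    o = sigmoid(x @ W_xo + h_prev @ W_ho + b_o)        # output gate
    c_tilde = np.tanh(x @ W_xc + h_prev @ W_hc + b_c)  # candidate memory
    c = f * c_prev + i * c_tilde                       # new cell state
    h = o * np.tanh(c)                                 # new hidden state
    return h, c

def init_params(n_in, n_hidden, rng):
    """Hypothetical helper: random (W_x, W_h, b) triples for i, f, o and the candidate."""
    params = []
    for _ in range(4):
        params += [rng.normal(0, 0.1, (n_in, n_hidden)),
                   rng.normal(0, 0.1, (n_hidden, n_hidden)),
                   np.zeros(n_hidden)]
    return params

rng = np.random.default_rng(42)
n_in, n_hidden, steps = 1, 8, 60
params = init_params(n_in, n_hidden, rng)
seq = rng.normal(size=(steps, 1, n_in))  # 60 time steps, batch of 1, 1 feature
h = np.zeros((1, n_hidden))
c = np.zeros((1, n_hidden))
for x_t in seq:  # carry (h, c) forward through the sequence
    h, c = lstm_step(x_t, h, c, params)
print(h.shape)  # (1, 8)
```

The additive update of `c` is what lets information survive many steps. The forget gate scales the old memory instead of overwriting it.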
Importing libraries
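import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')
plt.style.use("fivethirtyeight")
%matplotlib inline
# Third-party package for downloading financial stock data
import yfinance as yf
# Only needed when fetching data through pandas_datareader; the code below calls
# yf.download() directly. yf.pdr_override() may not exist in newer yfinance releases.
# from pandas_datareader.data import DataReader
# from pandas_datareader import data as pdr
# yf.pdr_override()
from datetime import datetime
from sklearn.preprocessing import MinMaxScaler  # data normalization (scales to [0, 1])
import warnings
warnings.filterwarnings("ignore")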
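With its default `feature_range=(0, 1)`, `MinMaxScaler` maps each feature to `(x - min) / (max - min)`. Below is a minimal NumPy sketch of that transform and its inverse. It also shows the sliding-window step that typically turns a scaled price series into LSTM inputs. The 60-day window, the synthetic price array and the helper names are assumptions for illustration only.

```python
import numpy as np

def min_max_scale(x):
    """Scale a 1-D array to [0, 1]; same formula as MinMaxScaler's default range."""
    x_min, x_max = x.min(), x.max()
    return (x - x_min) / (x_max - x_min), (x_min, x_max)

def inverse_scale(x_scaled, bounds):
    """Map scaled values (e.g. model predictions) back to price units."""
    x_min, x_max = bounds
    return x_scaled * (x_max - x_min) + x_min

def make_windows(series, window):
    """Each sample holds `window` consecutive values; its target is the next value."""
    X = np.array([series[i - window:i] for i in range(window, len(series))])
    y = series[window:]
    # (samples, time steps, features): the input layout Keras LSTM layers expect
    return X.reshape(-1, window, 1), y

prices = np.linspace(150.0, 250.0, 100)  # synthetic stand-in for a closing-price column
scaled, bounds = min_max_scale(prices)
X, y = make_windows(scaled, window=60)
print(X.shape, y.shape)  # (40, 60, 1) (40,)
```

In a real train/test split, compute the min and max on the training portion only and reuse them for the test portion. Otherwise test data leaks into the scaling.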
Getting the data
Download the data with yfinance:
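tech_list = ['AAPL', 'GOOG', 'MSFT', 'AMZN']  # the four companies to analyze
end = datetime.now()  # date range: the last three years up to today
start = end - pd.DateOffset(years=3)  # also safe on Feb 29, where datetime(end.year - 3, 2, 29) raises
for stock in tech_list:
    # Ticker + date range; each DataFrame is stored in a global variable named after its ticker.
    # auto_adjust=False keeps the "Adj Close" column used below (newer yfinance adjusts by default).
    globals()[stock] = yf.download(stock, start=start, end=end, auto_adjust=False)
company_list = [AAPL, GOOG, MSFT, AMZN]
company_name = ["APPLE", "GOOGLE", "MICROSOFT", "AMAZON"]
for company, com_name in zip(company_list, company_name):
    company["company_name"] = com_name  # tag every row with its company
df = pd.concat(company_list, axis=0)  # stack the four DataFrames vertically
df.tail()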
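The loop above creates one global variable per ticker through `globals()`. Below is a sketch of the same tag-and-concatenate step that keeps the DataFrames in a dict instead. Small synthetic frames stand in for the `yf.download` results (an assumption so the example runs offline), and the `frames` name is hypothetical.

```python
import numpy as np
import pandas as pd

tech_list = ["AAPL", "GOOG", "MSFT", "AMZN"]
company_name = ["APPLE", "GOOGLE", "MICROSOFT", "AMAZON"]
dates = pd.bdate_range("2024-08-26", periods=5)

# Synthetic stand-ins for yf.download(ticker, ...) results.
rng = np.random.default_rng(1)
frames = {
    ticker: pd.DataFrame(
        {"Adj Close": 100 + rng.normal(0, 1, len(dates)).cumsum(),
         "Volume": rng.integers(1_000_000, 5_000_000, len(dates))},
        index=dates,
    )
    for ticker in tech_list
}

for ticker, name in zip(tech_list, company_name):
    frames[ticker]["company_name"] = name  # tag every row with its company

# Stack the four frames vertically (axis=0), as in the code above.
df = pd.concat(frames.values(), axis=0)
print(df.groupby("company_name")["Adj Close"].last())
```

A dict avoids creating variables by name at run time and keeps the ticker-to-frame mapping in one place for later loops.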
Data overview
AAPL.head()  # four DataFrames are available: 'AAPL', 'GOOG', 'MSFT', 'AMZN'
AAPL.info()
<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 752 entries, 2021-09-03 to 2024-08-30
Data columns (total 7 columns):
 #   Column        Non-Null Count  Dtype
---  ------        --------------  -----
 0   Open          752 non-null    float64
 1   High          752 non-null    float64
 2   Low           752 non-null    float64
 3   Close         752 non-null    float64
 4   Adj Close     752 non-null    float64
 5   Volume        752 non-null    int64
 6   company_name  752 non-null    object
dtypes: float64(5), int64(1), object(1)
memory usage: 47.0+ KB
Closing Price
plt.figure(figsize=(15, 10))
plt.subplots_adjust(top=1.25, bottom=1.2)
for i, company in enumerate(company_list, 1):
    plt.subplot(2, 2, i)  # one panel per company in a 2x2 grid
    company["Adj Close"].plot()
    plt.ylabel("Adj Close")
    plt.xlabel(None)
    plt.title(f"Closing Price of {tech_list[i - 1]}")
plt.tight_layout()
Trading Volume
plt.figure(figsize=(15, 10))
plt.subplots_adjust(top=1.25, bottom=1.2)
for i, company in enumerate(company_list, 1):
    plt.subplot(2, 2, i)  # one panel per company in a 2x2 grid
    company["Volume"].plot()