目录
简介
准备写个系列博客介绍机器学习实战中的部分公开项目。首先从初级项目开始。
本文为初级项目第三篇:利用NSE-TATA数据集预测股票价格。
项目原网址为:Stock Price Prediction – Machine Learning Project in Python。
第一篇为:机器学习实战 | emojify 使用Python创建自己的表情符号(深度学习初级)
第二篇为:机器学习实战 | MNIST手写数字分类项目(深度学习初级)
技术流程
项目构想:
机器学习在股票价格预测中具有重要应用。在这个机器学习项目中,我们将讨论预测股票价格。这是一项非常复杂的任务,并且具有不确定性。
我们将学习如何使用 LSTM 神经网络
预测股票价格。
1. 载入依赖包
import matplotlib
matplotlib.use('Qt5Agg') # 防止画图时画图软件崩溃
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pylab import rcParams
rcParams['figure.figsize'] = 20, 10 # 设置画板尺寸
from keras.models import Sequential
from keras.layers import LSTM, Dropout, Dense
from sklearn.preprocessing import MinMaxScaler
项目中主要用了pandas
、sklearn
、Keras
和TensorFlow
包,pandas
和sklearn
安装命令为:
pip install pandas
pip install scikit-learn
Keras
和TensorFlow
的安装命令为:
pip install keras==2.10.0
pip install TensorFlow==2.10.0
在最后输出结果的时候发现每次画图软件都崩溃导致程序中断,解决办法就是在前面加上这句话:matplotlib.use('Qt5Agg')
,防止画图时画图软件崩溃。
2. 读取数据集
df = pd.read_csv("NSE-TATA.csv") # 读取.csv文件
df.head() # 默认只读取dataframe数据表中前5行内容
为了构建股票价格预测模型,我们将使用 NSE-TATA数据集。这是来自印度国家证券交易所塔塔全球饮料有限公司的塔塔饮料数据集,官方网址可能不好下载,这里给出了数据集下载地址:NSE-TATA数据集。
- df.head():读取dataframe数据表,默认只读取dataframe数据表中前5行内容
3. 从数据集中分析价格
df["Date"] = pd.to_datetime(df.Date, format="%Y-%m-%d") # 将一个字符串解析为时间,并指定字符串的格式
df.index = df['Date']
plt.figure(figsize=(8, 4)) # 指定图片大小
plt.plot(df["Close"], label='Close Price history') # 绘图展示历史数据
- pd.to_datetime:将字符串解析为时间,并指定字符串的格式
- plt.plot: 绘图展示历史数据,绘图结果为:
4. 对数据排序
data = df.sort_index(ascending=True, axis=0) # 索引排序:默认按行从小到大
new_dataset = pd.DataFrame(index=range(0, len(df)), columns=['Date', 'Close']) # 创建新的数据集
for i in range(0, len(data)):
new_dataset["Date"][i] = data['Date'][i]
new_dataset["Close"][i] = data["Close"