import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import statsmodels.api as sm
from statsmodels.tsa.arima_model import ARMA
from statsmodels.graphics.api import qqplot
import warnings
from itertools import product
from datetime import datetime, timedelta
import tushare as ts
import os

# Suppress all warnings to keep the console output clean.
# NOTE(review): this also hides deprecation warnings from numpy/pandas/statsmodels.
warnings.filterwarnings('ignore')

# Read the Tushare Pro API token from the environment so it does not need to be
# hard-coded in the source; fall back to the original placeholder when unset.
pro = ts.pro_api(token=os.environ.get('TUSHARE_TOKEN', '%your TOKEN'))

# Fetch daily bars for a single stock (002551.SZ) over the analysis window.
df = pro.daily(ts_code='002551.SZ', start_date='20200201', end_date='20200426')

# Order rows chronologically (oldest first) explicitly instead of relying on the
# order the API happens to return. sort_values returns a new frame, so the
# result must be assigned back (the original call discarded it).
# trade_date is 'YYYYMMDD', so lexical order equals chronological order.
df = df.sort_values(by='trade_date', ascending=True)
dfOpen = df[['open']]

# Persist the opening prices. Create the target directory first; otherwise
# writing the file fails when ~/stockData does not exist yet.
csv_path = os.path.expanduser('~/stockData/002551.csv')
os.makedirs(os.path.dirname(csv_path), exist_ok=True)
dfOpen.to_csv(csv_path)
# Convert the opening prices to a plain float Series with a fresh positional
# index. The builtin `float` replaces `np.float`, which was removed in NumPy 1.24.
data = pd.Series(np.asarray(dfOpen['open'], dtype=float))

# Relabel the trading days as consecutive fictitious annual dates starting in
# 1942; the forecasting code later addresses periods by year strings.
# The number of labels is derived from the data, so this no longer raises a
# length-mismatch error when the window does not contain exactly 59 trading
# days (for 59 rows it yields 1942..2000, as before).
data.index = pd.Index(sm.tsa.datetools.dates_from_range('1942', length=len(data)))

# Plot the raw opening-price series.
data.plot(figsize=(12, 8))
plt.show()
# Fit an ARMA model of order (p=7, q=0), i.e. an AR(7) model.
# NOTE(review): statsmodels.tsa.arima_model.ARMA was removed in statsmodels 0.13
# in favour of statsmodels.tsa.arima.model.ARIMA; this line needs an older release.
arma = ARMA(data, (7, 0)).fit()
print(f'AIC: {arma.aic:0.4f}')

# Forecast window: from the last (fictitious) in-sample year to ten years past it.
# Derived from the index rather than hard-coded, so it stays consistent with the
# data; for the 1942..2000 index this is '2000'..'2010', as before.
last_year = data.index[-1].year
forecast_start = str(last_year)
forecast_end = str(last_year + 10)

predict_y = arma.predict(forecast_start, forecast_end, dynamic=True)
print(predict_y)

# Plot the full observed series, then overlay the forecast on the same axes.
# `.ix` was removed in pandas 1.0; slicing from the first year was the whole
# series anyway, so the series is plotted directly.
fig, ax = plt.subplots(figsize=(12, 8))
ax = data.plot(ax=ax)
fig = arma.plot_predict(forecast_start, forecast_end, dynamic=True, ax=ax, plot_insample=False)
plt.show()