import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Real and predicted index series. The see_* plotting functions read the
# 'Real<code>', 'Lstm<code>', 'Rnn<code>' and 'Rl<code>' columns from it.
data = pd.read_csv(
    r'C:\Users\28715\Desktop\predict\LSTM_stock\picture\predict_values\values.csv'
)

# Plot settings shared by every figure.
figsize = (20, 10)   # (width, height) passed to plt.subplots
fontsize = 15        # size for tick labels, axis labels and legend text
alpha = 1            # line opacity for every plotted series
smooth_alpha = 3     # moving-average window length passed to smooth()
def smooth(y, box_pts):
    """Return the moving average of *y* over a window of *box_pts* samples.

    Only windows that lie entirely inside *y* are averaged (``mode='valid'``),
    so the result has ``len(y) - box_pts + 1`` elements. When *y* is shorter
    than the window there is no complete window, and an empty array is returned.

    Args:
        y: 1-D array-like of numbers (list, ndarray or pandas Series).
        box_pts: window length; expected to be a positive int.

    Returns:
        1-D float ndarray of averaged values.

    Raises:
        ValueError: from NumPy when *box_pts* is zero or negative.
    """
    values = np.asarray(y)
    # np.convolve swaps its arguments when the kernel is longer than the data,
    # which would return a meaningless series instead of "no valid windows".
    if values.size < box_pts:
        return np.empty(0)
    box = np.ones(box_pts) / box_pts
    return np.convolve(values, box, mode='valid')
def _plot_index(data, code):
    """Plot the real series and three model predictions for one index, then show it.

    Reads the columns ``'Real' + code``, ``'Lstm' + code``, ``'Rnn' + code`` and
    ``'Rl' + code`` from *data*. Each series is smoothed with ``smooth`` over a
    window of ``smooth_alpha`` samples, and all four are drawn on one figure.
    The real series is black, LSTM blue, RNN green and the 'Rl' model red
    (labelled PPORL). ``plt.show()`` is called at the end.

    Args:
        data: table indexable by column name (e.g. a pandas DataFrame).
        code: index suffix used in the column names, e.g. 'SZZS'.
    """
    # (column prefix, line colour, legend label), in drawing/legend order.
    series = [
        ('Real', 'k', code + ' Index'),
        ('Lstm', 'b', 'LSTM'),
        ('Rnn', 'g', 'RNN'),
        ('Rl', 'r', 'PPORL'),
    ]
    # Legend font.
    font = {'family': 'Times New Roman',
            'weight': 'normal',
            'size': fontsize,
            }

    _, ax = plt.subplots(figsize=figsize)
    # Tick label size.
    ax.tick_params(labelsize=fontsize)
    # `visible` is passed positionally: the old `b=` keyword was removed in
    # Matplotlib 3.7, and 'major' belongs to `which`, not to the visibility flag.
    ax.grid(True, which='major', color='k', linestyle='-', linewidth=0.5, alpha=0.4)
    for prefix, color, label in series:
        ax.plot(smooth(data[prefix + code], smooth_alpha), color,
                alpha=alpha, label=label)
    ax.legend(prop=font)
    ax.set_xlabel('Trading days', fontsize=fontsize)
    ax.set_ylabel('Index price', fontsize=fontsize)
    # Tick label font, set after plotting so it applies to the final ticks.
    for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
        tick_label.set_fontname(font['family'])
    plt.show()


def see_SZZS(data):
    """Plot real vs. predicted values for the columns suffixed 'SZZS'."""
    _plot_index(data, 'SZZS')


def see_SZCZ(data):
    """Plot real vs. predicted values for the columns suffixed 'SZCZ'."""
    _plot_index(data, 'SZCZ')


def see_HS300(data):
    """Plot real vs. predicted values for the columns suffixed 'HS300'."""
    _plot_index(data, 'HS300')
# Show one figure per index, in a fixed order.
for _show_index in (see_SZZS, see_SZCZ, see_HS300):
    _show_index(data)