# 实验数据 (experiment data)
from pandas import DataFrame
import pandas as pd; import numpy as np
import matplotlib.pyplot as plt
from matplotlib import dates as mdates
from matplotlib import ticker as mticker
from matplotlib.finance import candlestick_ohlc
from matplotlib.dates import DateFormatter
import datetime as dt
import pylab
import datetime
import talib
# 数据预处理 (data preprocessing)
# Load the daily quote CSV (GBK-encoded, first two lines skipped) and build
# ``df_new``: a frame indexed by parsed date with columns
# datetime, Open, High, Low, Close, Volume, Amount.
time_format = "%Y/%m/%d"
col_names = ['datetime', 'Open', 'High', 'Low', 'Close', 'Volume', 'Amount']
dtypes = ["object", "float", "float", "float", "float", "float", 'float']
mydata = np.genfromtxt("SH600006001.csv", delimiter=",", names=col_names,
                       skip_header=2, dtype=dtypes, encoding='gbk')


def _to_text(value):
    """Return *value* as ``str``, decoding it as GBK only if it is ``bytes``."""
    # Depending on the NumPy version, genfromtxt may hand back the object
    # column as bytes or as already-decoded str; calling .decode() on a str
    # would raise AttributeError.
    return value.decode("gbk") if isinstance(value, bytes) else value


Date = [datetime.datetime.strptime(_to_text(i), time_format)
        for i in mydata['datetime']]
data1 = DataFrame(mydata)
data2 = DataFrame(Date, columns=['DateTime'])
data = pd.concat([data1, data2], axis=1)
df_new = data.set_index(['DateTime'])
# Replace the raw text date column with the parsed dates from the index.
del df_new['datetime']
df_new['datetime'] = df_new.index
df_new = df_new[['datetime', 'Open', 'High', 'Low', 'Close', 'Volume', 'Amount']]
# Notebook-style preview; a bare expression has no effect when run as a script.
df_new.tail(3)
# 绘图 (plotting)
def dark_1():
    """Return the 12-colour palette of the 'dark_1' theme.

    Tuple order: background, rising candle, falling candle,
    moving-average line 1, moving-average line 2, spines, RSI line,
    RSI lower bound (30), RSI upper bound (70), volume & MACD fill,
    axis text, grid lines.
    """
    palette = (
        '#07000d',  # background
        '#ff1717',  # rising (bullish) candle
        '#53c156',  # falling (bearish) candle
        '#e1edf9',  # moving-average line 1
        '#4ee6fd',  # moving-average line 2
        '#5998ff',  # spines / border
        '#c1f9f7',  # RSI line
        '#386d13',  # RSI lower bound (30)
        '#8f2020',  # RSI upper bound (70)
        '#00ffe8',  # volume & MACD fill
        'w',        # axis text (x/y)
        'w',        # grid lines
    )
    return palette
def dark_2():
    """Return the 12-colour palette of the 'dark_2' theme.

    Tuple order: background, rising candle, falling candle,
    moving-average line 1, moving-average line 2, spines, RSI line,
    RSI lower bound (30), RSI upper bound (70), volume & MACD fill,
    axis text, grid lines.
    """
    palette = (
        '#1b2431',  # background
        '#ff1717',  # rising (bullish) candle
        '#53c156',  # falling (bearish) candle
        '#e1edf9',  # moving-average line 1
        '#4ee6fd',  # moving-average line 2
        '#5998ff',  # spines / border
        '#c1f9f7',  # RSI line
        '#386d13',  # RSI lower bound (30)
        '#8f2020',  # RSI upper bound (70)
        '#00ffe8',  # volume & MACD fill
        'w',        # axis text (x/y)
        'w',        # grid lines
    )
    return palette
def light_1():
    """Return the 12-colour palette of the 'light_1' theme.

    Tuple order: background, rising candle, falling candle,
    moving-average line 1, moving-average line 2, spines, RSI line,
    RSI lower bound (30), RSI upper bound (70), volume & MACD fill,
    axis text, grid lines.
    """
    palette = (
        'white',    # background
        'tomato',   # rising (bullish) candle
        'teal',     # falling (bearish) candle
        '#5c7a29',  # moving-average line 1
        '#6a3427',  # moving-average line 2
        '#5998ff',  # spines / border
        '#412f1f',  # RSI line
        '#386d13',  # RSI lower bound (30)
        '#8f2020',  # RSI upper bound (70)
        '#afb4db',  # volume & MACD fill
        'black',    # axis text (x/y)
        'grey',     # grid lines
    )
    return palette
def light_2():
    """Return the 12-colour palette of the 'light_2' theme.

    Tuple order: background, rising candle, falling candle,
    moving-average line 1, moving-average line 2, spines, RSI line,
    RSI lower bound (30), RSI upper bound (70), volume & MACD fill,
    axis text, grid lines.
    """
    palette = (
        '#E6E6FA',  # background
        'maroon',   # rising (bullish) candle
        '#11264f',  # falling (bearish) candle
        '#5F9EA0',  # moving-average line 1
        '#DB7093',  # moving-average line 2
        '#5998ff',  # spines / border
        '#2e3a1f',  # RSI line
        '#386d13',  # RSI lower bound (30)
        '#8f2020',  # RSI upper bound (70)
        '#7bbfea',  # volume & MACD fill
        'black',    # axis text (x/y)
        'grey',     # grid lines
    )
    return palette
def light_3():
    """Return the 12-colour palette of the 'light_3' theme.

    Tuple order: background, rising candle, falling candle,
    moving-average line 1, moving-average line 2, spines, RSI line,
    RSI lower bound (30), RSI upper bound (70), volume & MACD fill,
    axis text, grid lines.
    """
    palette = (
        'honeydew',       # background
        'gold',           # rising (bullish) candle
        'darkgreen',      # falling (bearish) candle
        'brown',          # moving-average line 1
        'darkslategray',  # moving-average line 2
        'black',          # spines / border
        'darkslategray',  # RSI line
        'darkgreen',      # RSI lower bound (30)
        'gold',           # RSI upper bound (70)
        'cadetblue',      # volume & MACD fill
        'black',          # axis text (x/y)
        'grey',           # grid lines
    )
    return palette
def light_4():
    """Return the 12-colour palette of the 'light_4' theme.

    Tuple order: background, rising candle, falling candle,
    moving-average line 1, moving-average line 2, spines, RSI line,
    RSI lower bound (30), RSI upper bound (70), volume & MACD fill,
    axis text, grid lines.
    """
    palette = (
        '#fffef9',    # background
        'crimson',    # rising (bullish) candle
        'darkgreen',  # falling (bearish) candle
        '#008080',    # moving-average line 1
        '#FFA500',    # moving-average line 2
        'black',      # spines / border
        '#4169E1',    # RSI line
        '#386d13',    # RSI lower bound (30)
        '#8f2020',    # RSI upper bound (70)
        'silver',     # volume & MACD fill
        'black',      # axis text (x/y)
        'grey',       # grid lines
    )
    return palette
def light_5():
    """Return the 12-colour palette of the 'light_5' theme.

    Tuple order: background, rising candle, falling candle,
    moving-average line 1, moving-average line 2, spines, RSI line,
    RSI lower bound (30), RSI upper bound (70), volume & MACD fill,
    axis text, grid lines.
    """
    palette = (
        '#EDEDED',        # background
        'brown',          # rising (bullish) candle
        'darkslategray',  # falling (bearish) candle
        '#145b7d',        # moving-average line 1
        '#973c3f',        # moving-average line 2
        'black',          # spines / border
        '#1d0200',        # RSI line
        'brown',          # RSI lower bound (30)
        'darkslategray',  # RSI upper bound (70)
        '#CDB7B5',        # volume & MACD fill
        'black',          # axis text (x/y)
        'grey',           # grid lines
    )
    return palette
def style(style):
    """Return the 12-colour palette tuple for the theme named *style*.

    Valid names are 'dark_1', 'dark_2', 'light_1', 'light_2', 'light_3',
    'light_4' and 'light_5'; each maps to the palette function of the same
    name.

    Raises:
        ValueError: if *style* is not one of the valid theme names.
    """
    palettes = {
        'dark_1': dark_1,
        'dark_2': dark_2,
        'light_1': light_1,
        'light_2': light_2,
        'light_3': light_3,
        'light_4': light_4,
        'light_5': light_5,
    }
    try:
        palette = palettes[style]
    except KeyError:
        # Fail loudly here: returning None (as an if/elif chain without an
        # else would) only surfaces later as an opaque unpacking TypeError in
        # the caller, e.g. after a typo in the interactive prompt.
        raise ValueError('unknown style {!r}; choose one of: {}'.format(
            style, ', '.join(palettes))) from None
    return palette()
def scope(start, end):
    """Return the slice of the module-level ``df_new`` between *start* and *end*.

    ``df_new`` is indexed by date, so *start*/*end* are used as index labels
    (e.g. ``'2017/3/30'``); pandas label slicing includes both endpoints.
    """
    return df_new[start:end]
# Window lengths (in rows/trading days) of the two simple moving averages
# drawn over the candlesticks.
MA1, MA2 = 10, 50
def _color_axes(ax, edge_color, tick_color):
    """Colour all four spines of *ax* with *edge_color* and its x/y ticks with *tick_color*."""
    for side in ('bottom', 'top', 'left', 'right'):
        ax.spines[side].set_color(edge_color)
    ax.tick_params(axis='x', colors=tick_color)
    ax.tick_params(axis='y', colors=tick_color)


def drawing(start, end, Style):
    """Draw a candlestick chart with SMA, volume, RSI and MACD panels.

    Layout on a 6x4 grid: RSI on row 0, candlesticks + two SMAs + a volume
    overlay on rows 1-4, MACD on row 5. The figure is saved as
    ``<Style>.png`` and then shown.

    Args:
        start, end: date-range bounds passed to ``scope`` (e.g. '2017/3/30').
        Style: theme name passed to ``style`` (e.g. 'dark_1').

    Raises:
        ValueError: if the selected date range contains no rows.
    """
    (color1, color2, color3, color4, color5, color6, color7, color8,
     color9, color10, color11, color12) = style(Style)

    days = scope(start, end)
    if len(days) == 0:
        # Without this, the first failure would be an obscure error from
        # max() on an empty volume array.
        raise ValueError('no data between {} and {}'.format(start, end))

    # OHLC table in the column order candlestick_ohlc expects, with the date
    # converted to a matplotlib date number.
    daysreshape = days.reset_index()
    daysreshape['DateTime'] = mdates.date2num(
        daysreshape['datetime'].dt.to_pydatetime())
    daysreshape = daysreshape[['DateTime', 'Open', 'High', 'Low', 'Close']]
    close = daysreshape['Close'].values

    # Simple moving averages (pd.rolling_mean no longer exists in pandas);
    # the first MAx-1 values are NaN.
    sma_1 = daysreshape['Close'].rolling(MA1).mean().values
    sma_2 = daysreshape['Close'].rolling(MA2).mean().values

    # Plot from the first row where the MA2 average exists. If the range has
    # fewer than MA2 rows there is no such row and every row is plotted.
    first = MA2 - 1 if len(daysreshape) >= MA2 else 0
    x = daysreshape['DateTime'].values[first:]

    fig = plt.figure(facecolor=color1, figsize=(15, 10))

    # --- Price panel: candlesticks + SMAs ---
    # ``facecolor`` replaces the deprecated ``axisbg`` keyword.
    ax1 = plt.subplot2grid((6, 4), (1, 0), rowspan=4, colspan=4, facecolor=color1)
    candlestick_ohlc(ax1, daysreshape.values[first:], width=.6,
                     colorup=color2, colordown=color3)
    ax1.plot(x, sma_1[first:], color=color4, label=str(MA1) + ' SMA', linewidth=1.5)
    ax1.plot(x, sma_2[first:], color=color5, label=str(MA2) + ' SMA', linewidth=1.5)
    # Build the legend only after the labelled SMA lines exist; calling
    # plt.legend on the still-empty figure produces an empty legend.
    maLeg = ax1.legend(loc=9, ncol=2, prop={'size': 7},
                       fancybox=True, borderaxespad=0.)
    maLeg.get_frame().set_alpha(0.4)
    for text in maLeg.get_texts():
        text.set_color(color11)
    ax1.grid(True, color=color12)
    ax1.xaxis.set_major_locator(mticker.MaxNLocator(10))
    ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
    ax1.yaxis.set_major_locator(mticker.MaxNLocator(prune='upper'))
    _color_axes(ax1, color6, color11)
    ax1.set_ylabel('Stock price and Volume', color=color11)

    # --- RSI panel (top) ---
    ax0 = plt.subplot2grid((6, 4), (0, 0), sharex=ax1, rowspan=1, colspan=4,
                           facecolor=color1)
    rsi = talib.RSI(close)[first:]
    ax0.plot(x, rsi, color=color7, linewidth=1.5)
    ax0.axhline(70, color=color9)
    ax0.axhline(30, color=color8)
    ax0.fill_between(x, rsi, 70, where=(rsi >= 70),
                     facecolor=color9, edgecolor=color9, alpha=0.5)
    ax0.fill_between(x, rsi, 30, where=(rsi <= 30),
                     facecolor=color8, edgecolor=color8, alpha=0.5)
    ax0.set_yticks([30, 70])
    _color_axes(ax0, color6, color11)
    ax0.set_ylabel('RSI', color=color11)

    # --- Volume, overlaid on the price panel with its own hidden y scale ---
    ax1v = ax1.twinx()
    ax1v.fill_between(x, 0, days['Volume'].values[first:],
                      facecolor=color10, alpha=.4)
    ax1v.yaxis.set_ticklabels([])
    ax1v.grid(False)
    # y-limit of 3x the peak keeps the volume fill within the lower third.
    ax1v.set_ylim(0, 3 * days['Volume'].values.max())
    _color_axes(ax1v, color6, color11)

    # --- MACD panel (bottom) ---
    ax2 = plt.subplot2grid((6, 4), (5, 0), sharex=ax1, rowspan=1, colspan=4,
                           facecolor=color1)
    nslow = 26
    nfast = 12
    nema = 9
    # talib.MACD returns (macd line, signal line, histogram) in that order.
    macd, signal, hist = talib.MACD(close, fastperiod=nfast,
                                    slowperiod=nslow, signalperiod=nema)
    ax2.plot(x, macd[first:], color=color5, lw=2)
    ax2.plot(x, signal[first:], color=color4, lw=1)
    ax2.fill_between(x, hist[first:], 0, alpha=0.5,
                     facecolor=color10, edgecolor=color10)
    _color_axes(ax2, color6, color11)
    # Use the theme's text colour; a hard-coded white label is invisible on
    # the light themes.
    ax2.set_ylabel('MACD', color=color11)
    ax2.yaxis.set_major_locator(mticker.MaxNLocator(nbins=5, prune='upper'))
    for label in ax2.xaxis.get_ticklabels():
        label.set_rotation(45)

    # Pass the figure facecolor explicitly: older matplotlib versions save
    # with a white background by default, which breaks the dark themes.
    fig.savefig('{}.png'.format(Style), facecolor=fig.get_facecolor())
    plt.show()
# Interactive entry point: prompt for a theme and draw a fixed date range.
print('请输入时间范围(1999/7/27~2018/3/30)')
# NOTE: the date range is currently hard-coded below rather than read from
# the user; to make it interactive, read start/end with input() here.
print('请输入界面风格:\ndark_1\ndark_2\nlight_1\nlight_2\nlight_3\nlight_4\nlight_5')
# Strip surrounding whitespace so an accidental space still matches a theme name.
Style = input().strip()
start = '2017/3/30'
end = '2017/12/30'
drawing(start, end, Style)