import pandas as pd
import matplotlib.pyplot as plt
import pymysql
import sys
from mpl_finance import candlestick2_ochl
from matplotlib.ticker import MultipleLocator
# Compute an exponential moving average (EMA) of the close price.
def calEMA(df, term):
    """Return the ``term``-period EMA of ``df['close']`` as a list.

    Recurrence (first value seeded with the first close):
        EMA[0] = close[0]
        EMA[i] = (term - 1) / (term + 1) * EMA[i - 1] + 2 / (term + 1) * close[i]
    This is equivalent to ``Series.ewm(span=term, adjust=False).mean()``,
    which is what is used here instead of a per-row Python loop.

    Side effect: the result is also stored in ``df['EMA']`` (overwriting any
    previous ``EMA`` column), as earlier versions of this function did.

    Args:
        df: DataFrame with a ``close`` column, rows in chronological order.
        term: EMA period (span); must be >= 1.

    Returns:
        list with one EMA value per row of ``df``; ``[]`` for an empty ``df``.
    """
    # Cast to float: pymysql returns DECIMAL columns as decimal.Decimal, and
    # float * Decimal raises TypeError. Whether 'close' is DECIMAL depends on
    # the table schema, which is not visible here.
    close = df['close'].astype(float)
    ema = close.ewm(span=term, adjust=False).mean()
    # Assigning a Series aligns on df.index, so any index type works.
    df['EMA'] = ema
    return list(ema)
# Compute the MACD indicator (DIF, DEA and the MACD histogram).
def calMACD(df, shortTerm=12, longTerm=26, DIFTerm=9):
    """Add ``DIF``, ``DEA`` and ``MACD`` columns to ``df`` and return it.

        DIF  = EMA(close, shortTerm) - EMA(close, longTerm)
        DEA  = EMA(DIF, DIFTerm), seeded with the first DIF value
        MACD = 2 * (DIF - DEA)

    ``df`` is modified in place (``calEMA`` also writes an ``EMA`` column)
    and the same object is returned.

    Args:
        df: DataFrame with a ``close`` column, rows in chronological order.
        shortTerm: period of the fast EMA.
        longTerm: period of the slow EMA.
        DIFTerm: period of the EMA applied to DIF to obtain DEA (>= 1).

    Returns:
        ``df`` with the three indicator columns added.
    """
    # Wrap the EMA lists with df's own index so the subtraction and the
    # column assignment line up row-for-row even when df is not indexed
    # 0..n-1 (a bare pd.Series(list) gets a RangeIndex and aligns to NaN).
    shortEMA = pd.Series(calEMA(df, shortTerm), index=df.index)
    longEMA = pd.Series(calEMA(df, longTerm), index=df.index)
    df['DIF'] = shortEMA - longEMA
    # Same recurrence as the EMA: DEA[0] = DIF[0],
    # DEA[i] = (DIFTerm-1)/(DIFTerm+1) * DEA[i-1] + 2/(DIFTerm+1) * DIF[i].
    df['DEA'] = df['DIF'].ewm(span=DIFTerm, adjust=False).mean()
    df['MACD'] = 2 * (df['DIF'] - df['DEA'])
    return df
# Database connection settings.
# NOTE(review): credentials are hard-coded in source; consider loading them
# from environment variables or a config file.
dbhost = 'localhost'
dbuser = 'root'
dbpass = 'rootroot'
dbname = 'pythonstock'
try:
    db = pymysql.connect(host=dbhost, user=dbuser, password=dbpass, database=dbname)
except pymysql.MySQLError as err:
    # Report the underlying cause and exit with a non-zero status so that
    # callers (shell scripts, schedulers) can tell the run failed.
    print('数据库连接错误', err, file=sys.stderr)
    sys.exit(1)
# Query ========================================
# Load all rows for the stock in ascending date order; the EMA/MACD
# recurrences depend on chronological order.
try:
    with db.cursor() as cursor:
        # Parameterized query instead of a double-quoted SQL literal, which
        # MySQL treats as an identifier under the ANSI_QUOTES SQL mode.
        cursor.execute(
            'select * from stockInfo where stockCode=%s order by date asc',
            ('600895',),
        )
        result = cursor.fetchall()
        col = [column[0] for column in cursor.description]
finally:
    # The connection is only needed for this query; release it right away.
    db.close()
df = pd.DataFrame(result, columns=col)
stockDataFrame = calMACD(df, 12, 26, 9)
print(stockDataFrame)
# Chinese titles/labels need a CJK font. SimHei has no U+2212 minus glyph, so
# use an ASCII '-' for negative tick labels (the MACD axis goes below zero).
# Set both before any artists are created.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# Two stacked panels sharing the x axis: price (candles + moving averages) on
# top, MACD below. Everything is drawn against the row position 0..n-1, the
# x coordinate candlestick2_ochl uses for its candles, so all panels line up
# regardless of the type of the 'date' column.
figure, (axPrice, axMACD) = plt.subplots(2, sharex=True, figsize=(15, 8))
positions = list(range(len(stockDataFrame)))

candlestick2_ochl(ax=axPrice,
                  opens=stockDataFrame['open'].values,
                  closes=stockDataFrame['close'].values,
                  highs=stockDataFrame['high'].values,
                  lows=stockDataFrame['low'].values,
                  width=0.75, colorup='red', colordown='green')
axPrice.set_title('600895K线和均线图')
# use_index=False plots against 0..n-1, matching the candle positions.
stockDataFrame['close'].rolling(window=3).mean().plot(ax=axPrice, color="red", label="3日均线", use_index=False)
stockDataFrame['close'].rolling(window=5).mean().plot(ax=axPrice, color="blue", label="5日均线", use_index=False)
stockDataFrame['close'].rolling(window=10).mean().plot(ax=axPrice, color="green", label="10日均线", use_index=False)
axPrice.legend(loc="best")
axPrice.set_ylabel('价格(单位:元)')
axPrice.grid(True)

stockDataFrame['DEA'].plot(ax=axMACD, color="red", label="DEA", use_index=False)
stockDataFrame['DIF'].plot(ax=axMACD, color="blue", label="DIF", use_index=False)
axMACD.legend(loc="best")
# MACD histogram: red when MACD > 0, green otherwise (zero and NaN included).
barColors = ['red' if value > 0 else 'green' for value in stockDataFrame['MACD']]
axMACD.bar(positions, stockDataFrame['MACD'].values, width=0.5, color=barColors)
axMACD.set_title('600895MACD图')
axMACD.grid(linestyle="-.")
# Label every 10th row with its date (positional, independent of the index).
axMACD.set_xticks(positions[::10])
axMACD.set_xticklabels([str(d) for d in stockDataFrame['date'].iloc[::10]], rotation=30)
plt.show()