运用Matplotlib annotate()函数作注解并在Subplot上绘图
首先可以看下想要实现的效果:
以下代码将会根据标普500指数收盘价格绘制一张曲线图,并标出2008年到2009年金融危机期间的一些重要日期。
1、导入库
from datetime import datetime
import pandas_datareader.data as web
import pandas as pd
import matplotlib.pyplot as plt
2、获取标普500指数数据
data = web.get_data_yahoo('^GSPC', start='1983-01-01', end='2019-08-22')
data.tail()
# 不妨用tail()命令了解输出的DataFrame的结构
# 需要注意从这个库中获取的数据,已经将Date作为了index
3、开始绘图
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
注:实际上这两段代码等价于:fig,ax = plt.subplots()
,它的意思是建立一个fig
对象,建立一个axis
对象。个人理解是就像新建一个DataFrame
需要输入pd.Dataframe()
一样。
# 接下来开始绘制标普500收盘价的走势,
spx = data['Close']
spx.plot(ax=ax, style='k-')
注:style='k-'
指的是绘制黑色实线,绘制黑色虚线则是 style='k--'
,而 ax
控制子图的位置,即 (1,1,1)
,详细的理解参考这篇博客的案例1:https://blog.csdn.net/huozi07/article/details/45868021,它很清楚地解释了 ax
控制子图位置的功能。
# 现在开始在指数走势图上标注重要的事件
# 先罗列出一下的重要事件以及对应的事件
crisis_date = [(datetime(2007, 10, 11), 'Peak of bull market'),
(datetime(2008, 3, 12), 'Bear Stearns Fails'),
(datetime(2008, 9, 15), 'Lehman Bankruptcy')]
# 然后写一个循环结构,批量将事件录进绘制的途中,其中使用到了annotate()函数
for date, label in crisis_date:
ax.annotate(label, xy=(date, spx.asof(date)+50),
xytext=(date, spx.asof(date)+200),
arrowprops=dict(facecolor='black'),
horizontalalignment='left', verticalalignment='top')
注:annotate()函数参数很多比较复杂,个人认为在股价走势图上简要地添加注解并不会用到太多的复杂功能,以后如果自己在绘制这方面的图中出现了其他需求将会对这篇博客中进行补充。
唯一需要注意以下五点:
date
和label
是两个Series
,label
的type
是str
xy
代表被注释的坐标点,二维元组形如(x,y);xytext
代表注释文本的坐标点,也是二维元组,如果没有写出的话,默认与xy相同asof()
函数可以理解为:最后一行不是NaN值的值通俗的说:假如我有一组数据,某个点的时候这个值是NaN,那就求这个值之前最近一个不是NaN的值是多少。可以参考这篇博客的例子:https://blog.csdn.net/maymay_/article/details/80252587。horizontalalignment
和verticalalignment
,则是控制用于标注的文本框相对于箭头位置,可以参考本例进行理解。在本例中设置为left
和top
。- 关于
arrowprops
或arrowstyle
的叙述,参考这篇博客:https://blog.csdn.net/leaf_zizi/article/details/82886755,可以将其与官方文档https://matplotlib.org/tutorials/text/annotations.html#id26进行对照学习。
# 放大至2007-2009
ax.set_xlim('2007-01-01', '2011-01-01')
ax.set_ylim([600, 1800])
ax.set_title('Important dates in 2008-2009 financial crisis')
plt.grid(True)
plt.show()
最后绘出的图形如下:
4、完整代码
# 根据标普500指数收盘价格绘制一张曲线图,并标出2008年到2009年金融危机期间的一些重要日期
# import quandl
# quandl.ApiConfig.api_key = 'iUy778LZbzD--********'
# data = quandl.get("WIKI/AAPL", start_date='2000-01-01', end_date='2019-08-22')
from datetime import datetime
import pandas_datareader.data as web
import pandas as pd
import matplotlib.pyplot as plt
# pd.set_option('display.max_rows', 5000)
pd.set_option('display.max_columns', 5000)
pd.set_option('expand_frame_repr', False)
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
# ===获取数据
data = web.get_data_yahoo('^GSPC', start='1983-01-01', end='2019-08-22')
# data['Date'] = pd.to_datetime(data['Date'])
print(data.tail())
spx = data['Close']
spx.plot(ax=ax, style='k-')
# ===事件整理
crisis_date = [(datetime(2007, 10, 11), 'Peak of bull market'),
(datetime(2008, 3, 12), 'Bear Stearns Fails'),
(datetime(2008, 9, 15), 'Lehman Bankruptcy')]
for date, label in crisis_date:
ax.annotate(label, xy=(date, spx.asof(date)+50),
xytext=(date, spx.asof(date)+200),
arrowprops=dict(facecolor='black'),
horizontalalignment='left', verticalalignment='top')
# 放大至2007-2009
ax.set_xlim('2007-01-01', '2011-01-01')
ax.set_ylim([600, 1800])
ax.set_title('Important dates in 2008-2009 financial crisis')
plt.grid(True)
plt.show()