Python模块 —— Matplotlib
Matplotlib(二)—— 子图
大家可以关注知乎或微信公众号的share16,我们也会同步更新此文章。
四、子图
4.1 均匀子图
4.1.1 plt.subplots
fig,ax = plt.subplots(nrows, ncols, figsize, sharex, sharey, squeeze, subplot_kw, gridspec_kw, fig_kw)
- 返回值:type(fig)是matplotlib.figure.Figure,type(ax)是numpy.ndarray;
- nrows/ncols:创建行数*列数的子图;
- figsize:指定整个画布的大小,以元组形式输入;
- sharex/sharey:是否共享x/y轴的属性; squeeze:默认True;
- subplot_kw:带有关键字的字典,该关键字传递给用于创建每个子图的add_subplot调用;
- gridspec_kw:带有关键字的字典,该关键字传递给用于创建每个子图网格的GridSpec函数;
4.1.2 plt.subplot
plt.subplot(nrows, ncols, index, projection, **kwargs)
- projection:可取值{None / aitoff / hammer / lambert / mollweide / polar / rectilinear }
plt.tight_layout()
调整子图的相对大小使字符不会重叠
4.2 非均匀子图
所谓非均匀包含两层含义,第一是指图的比例大小不同但没有跨行或跨列,第二是指图为跨列或跨行状态。
4.2.1 fig.add_gridspec
spec = fig.add_gridspec(nrows, ncols, width_ratios, height_ratios, **kwargs)
- 没有 plt.add_gridspec 函数;
- width_ratios/height_ratios:列表里的元素数与列数(行数)相等,代表各列(行)的宽度(高度)比例;
ax_k = fig.add_gridspec(spec[i,j])
spec[i,j]类似于索引
4.3 子图上的方法
fig,ax = plt.subplots(nrows, ncols, figsize, ···)
- ax对象上定义了和plt类似的图形绘制函数,常用的有:plot、hist、scatter、bar、barh、pie;
- 常用直线(水平/垂直/任意方向)的画法,依次为:ax.axhline(y,xmin,xmax)、ax.axvline(x,ymin,ymax)、ax.axline((x1,y1),(x2,y2))或ax.axline((x1,y1),slope) slope是斜率;
- 设置坐标轴的规度(对数坐标等)、标题、轴名,依次为:set_xscale、set_title、set_xlabel;
- 添加网格、图例:ax.grid(alpha=0.5,color=‘m’)、ax.legend();
图例
:放在plot函数中,方能显示全部字符,若放置在legend中,则只显示一个字符; - 添加文本注释:ax.annotate(s,xy,xytext,arrowprops,···),(s:注释的文本、xy:要注释的点、xytext:放置文本的位置、arrowprops:dict类型);
- 添加箭头、将文本添加到(x,y)处:ax.arrow、ax.text(x,y,s,···);
4.4 墨尔本温度数据集
绘制出墨尔本1981~1990年的温度曲线,数据集点此下载
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
df = pd.read_csv('/xxx/08 墨尔本温度.csv')
# Series操作时,转变成字符串
df['year'] = df.Time.str.split('-',expand=True)[0]
df.set_index('year',inplace=True)
rows,cols = 2,5
year = np.array(df.index.unique()).reshape(rows,cols)
fig,ax = plt.subplots(rows,cols,figsize=(15,4),sharex=True,sharey=True)
fig.suptitle('墨尔本1981~1990年温度曲线')
for i in range(2):
for j in range(5):
k = str(year[i][j])
t = df.loc[k].Temperature
ax[i][j].plot(range(1,len(t)+1),t,marker='*')
ax[i][j].set_title(k+'年')
if j == 0:
ax[i][j].set_ylabel('气温')
fig.tight_layout()
4.5 画出数据的散点图和边际分布图
import matplotlib.pyplot as plt
import numpy as np
x,y = np.random.randn(2,150)
fig = plt.figure(figsize=(6,4))
spec = fig.add_gridspec(2,2,width_ratios=[5,1],height_ratios=[1,6])
ax1 = fig.add_subplot(spec[1,0])
ax1.scatter(x,y)
ax1.grid()
ax2 = fig.add_subplot(spec[0,0])
ax2.hist(x,bins=10)
ax2.axis('off') # 隐藏坐标系
ax3 = fig.add_subplot(spec[1,1])
ax3.hist(y,bins=10,orientation='horizontal') # 直方图变成水平方向
ax3.axis('off') # 隐藏坐标系
fig.tight_layout()
plt.subplots_adjust(wspace=0, hspace=0) # 调整子图间距
谢谢大家 🌹