目录
1、绘制普通热力图
# -*- coding:utf-8 -*-
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Build random data: np.corrcoef treats each row as one variable, so 7 rows of
# 12 samples give a 7x7 correlation matrix.
data = np.random.random((7, 12))
corr = np.corrcoef(data)
# One tick label per variable (row/column of corr).
bands_wavelength = ["400", "500", "600", "700", "800", "900", "1000"]
# Mark the lower triangle, including the diagonal.
mask = np.zeros_like(corr, dtype=np.bool_)
mask[np.tril_indices_from(mask)] = True
cmap = sns.diverging_palette(220, 10, as_cmap=True)
# Flip the rows of both matrices so they stay aligned. After the flip the mask
# is symmetric, so mask.T equals mask. It hides the top-left triangle up to and
# including the anti-diagonal, which now holds corr's (all-ones) diagonal.
corr = np.flip(corr, axis=0)
mask = np.flip(mask, axis=0)
ax = sns.heatmap(corr, mask=mask.T, cmap=cmap, square=True, linewidths=0.5,
                 # vmin/vmax set the colour range explicitly.
                 vmin=np.min(corr), vmax=np.max(corr),
                 # Keep the diverging palette's neutral colour at r = 0, so
                 # positive and negative correlations get opposite hues.
                 center=0,
                 cbar=True,
                 xticklabels=bands_wavelength,
                 # Rows were flipped, so the y labels are reversed to match.
                 yticklabels=bands_wavelength[::-1])
# plt.title("平方数", fontsize=24)  # set the title
# plt.xlabel("值", fontsize=14)  # set the x-axis label
# plt.ylabel("值的平方", fontsize=14)  # set the y-axis label
# plt.savefig('./热力图.png',bbox_inches='tight',dpi=300)
plt.show()
2、坐标轴标签过多时,自定义刻度标签的显示
# -*- coding:utf-8 -*-
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# 100 variables x 12 samples -> a 100x100 correlation matrix. That is far more
# rows/columns than tick labels we want to show.
data = np.random.random((100, 12))
corr = np.corrcoef(data)
# Only these 7 labels are shown, spread evenly along each axis.
bands_wavelength = ["400", "500", "600", "700", "800", "900", "1000"]
mask = np.zeros_like(corr, dtype=np.bool_)
mask[np.tril_indices_from(mask)] = True
cmap = sns.diverging_palette(220, 10, as_cmap=True)
# Flip rows of both matrices together (mask stays aligned with corr).
corr = np.flip(corr, axis=0)
mask = np.flip(mask, axis=0)
ax = sns.heatmap(corr, mask=mask.T, cmap=cmap, square=True, linewidths=0.5,
                 vmin=np.min(corr), vmax=np.max(corr),  # explicit colour range
                 center=0,  # neutral palette colour at zero correlation
                 cbar=True,
                 xticklabels=bands_wavelength,
                 yticklabels=bands_wavelength[::-1])
# Replace the per-cell ticks with len(labels) evenly spaced ticks, running from
# the first cell edge (0) to the last (number of columns/rows of the heatmap).
# Positions are derived from corr itself, the matrix actually drawn.
x_ticks = np.linspace(0, corr.shape[1], len(bands_wavelength))
y_ticks = np.linspace(0, corr.shape[0], len(bands_wavelength))
ax.set_xticks(x_ticks)
ax.set_yticks(y_ticks)
ax.set_xticklabels(bands_wavelength)
ax.set_yticklabels(bands_wavelength[::-1])
# plt.title("平方数", fontsize=24)  # set the title
# plt.xlabel("值", fontsize=14)  # set the x-axis label
# plt.ylabel("值的平方", fontsize=14)  # set the y-axis label
# plt.savefig('./热力图.png',bbox_inches='tight',dpi=300)
plt.show()
3、不显示热图的网格
# -*- coding:utf-8 -*-
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# 100 variables x 12 samples -> a 100x100 correlation matrix.
data = np.random.random((100, 12))
corr = np.corrcoef(data)
# Only these 7 labels are shown, spread evenly along each axis.
bands_wavelength = ["400", "500", "600", "700", "800", "900", "1000"]
mask = np.zeros_like(corr, dtype=np.bool_)
mask[np.tril_indices_from(mask)] = True
cmap = sns.diverging_palette(220, 10, as_cmap=True)
# Flip rows of both matrices together (mask stays aligned with corr).
corr = np.flip(corr, axis=0)
mask = np.flip(mask, axis=0)
# linewidths=0 removes the grid lines drawn between cells.
ax = sns.heatmap(corr, mask=mask.T, cmap=cmap, square=True, linewidths=0.,
                 vmin=np.min(corr), vmax=np.max(corr),  # explicit colour range
                 center=0,  # neutral palette colour at zero correlation
                 cbar=True,
                 xticklabels=bands_wavelength,
                 yticklabels=bands_wavelength[::-1])
# Evenly spaced ticks from the first cell edge (0) to the last, derived from
# the shape of the matrix actually drawn.
x_ticks = np.linspace(0, corr.shape[1], len(bands_wavelength))
y_ticks = np.linspace(0, corr.shape[0], len(bands_wavelength))
ax.set_xticks(x_ticks)
ax.set_yticks(y_ticks)
ax.set_xticklabels(bands_wavelength)
ax.set_yticklabels(bands_wavelength[::-1])
# plt.title("平方数", fontsize=24)  # set the title
# plt.xlabel("值", fontsize=14)  # set the x-axis label
# plt.ylabel("值的平方", fontsize=14)  # set the y-axis label
# plt.savefig('./热力图.png',bbox_inches='tight',dpi=300)
plt.show()
4、自定义颜色条的距离、标签
方法1(推荐):
# -*- coding:utf-8 -*-
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as tkr

# 100 variables x 12 samples -> a 100x100 correlation matrix.
data = np.random.random((100, 12))
corr = np.corrcoef(data)
# Only these 7 labels are shown, spread evenly along each axis.
bands_wavelength = ["400", "500", "600", "700", "800", "900", "1000"]
mask = np.zeros_like(corr, dtype=np.bool_)
mask[np.tril_indices_from(mask)] = True
cmap = sns.diverging_palette(220, 10, as_cmap=True)
# Flip rows of both matrices together (mask stays aligned with corr).
corr = np.flip(corr, axis=0)
mask = np.flip(mask, axis=0)
# --------- colour bar settings ----------
# Ticks outside [vmin, vmax] are not drawn on the colour bar.
cbar_ticks = [-1, -0.5, 0, 0.5, 1]
# Plain (non-scientific) tick labels.
formatter = tkr.ScalarFormatter(useMathText=True)
formatter.set_scientific(False)
# ----------------------------------------
# cbar_kws keys used here:
#   'label':  title of the colour bar
#   'ticks':  tick positions on the colour bar
#   'format': tick-label format (a format string such as '%.0f', or a Formatter)
#   'pad':    gap between colour bar and heatmap; a larger gap squeezes the heatmap
ax = sns.heatmap(corr, mask=mask.T, cmap=cmap, square=True, linewidths=0.,
                 vmin=np.min(corr), vmax=np.max(corr),  # explicit colour range
                 center=0,  # neutral palette colour at zero correlation
                 cbar=True,
                 xticklabels=bands_wavelength,
                 yticklabels=bands_wavelength[::-1],
                 cbar_kws={"ticks": cbar_ticks, "format": formatter,
                           'label': 'A', 'pad': 0.01})
# Work on the heatmap's own colour bar. This is robust, unlike grabbing the
# figure's last Axes, which breaks once the figure holds other Axes.
cbar = ax.collections[0].colorbar
cbar.ax.yaxis.set_ticks([], minor=True)  # drop minor ticks
cbar.ax.tick_params(labelsize=11)  # colour-bar tick label size
# Font for the colour-bar label.
font1 = {'family': 'Times New Roman', 'size': 11, 'color': '#000000'}
cbar.set_label('A', fontdict=font1)
# --------- custom axis tick labels ----------
# Evenly spaced ticks from the first cell edge (0) to the last, derived from
# the shape of the matrix actually drawn.
x_ticks = np.linspace(0, corr.shape[1], len(bands_wavelength))
y_ticks = np.linspace(0, corr.shape[0], len(bands_wavelength))
ax.set_xticks(x_ticks)
ax.set_yticks(y_ticks)
ax.set_xticklabels(bands_wavelength)
ax.set_yticklabels(bands_wavelength[::-1])
# --------------------------------------------
# plt.title("平方数", fontsize=24)  # set the title
# plt.xlabel("值", fontsize=14)  # set the x-axis label
# plt.ylabel("值的平方", fontsize=14)  # set the y-axis label
# plt.savefig('./热力图.png',bbox_inches='tight',dpi=300)
plt.show()
方法2:
# -*- coding:utf-8 -*-
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

# 100 variables x 12 samples -> a 100x100 correlation matrix.
data = np.random.random((100, 12))
corr = np.corrcoef(data)
# Only these 7 labels are shown, spread evenly along each axis.
bands_wavelength = ["400", "500", "600", "700", "800", "900", "1000"]
mask = np.zeros_like(corr, dtype=np.bool_)
mask[np.tril_indices_from(mask)] = True
cmap = sns.diverging_palette(220, 10, as_cmap=True)
# Flip rows of both matrices together (mask stays aligned with corr).
corr = np.flip(corr, axis=0)
mask = np.flip(mask, axis=0)
# Colour-bar tick values. The labels below are generated from these same values,
# so every label states the value it sits at. The colour range is pinned to the
# full correlation range [-1, 1] so that all three ticks are visible. Because the
# range is symmetric, 0 falls on the palette's neutral colour.
cbar_ticks = [-1, 0, 1]
ax = sns.heatmap(corr, mask=mask.T, cmap=cmap, square=True, linewidths=0.,
                 vmin=-1, vmax=1,  # explicit colour range
                 cbar=False,  # the colour bar is added manually below
                 xticklabels=bands_wavelength,
                 yticklabels=bands_wavelength[::-1])
# --------- colour bar settings ----------
# Carve a colour-bar Axes (5% of the heatmap width) off the right-hand side.
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.1)
cbar = plt.colorbar(ax.collections[0], cax=cax)
cbar.set_ticks(cbar_ticks)
cbar.ax.set_yticklabels([str(t) for t in cbar_ticks], size=20)
# Hide the tick marks and push the labels away from the bar.
cbar.ax.tick_params(axis='y', which='major', length=0, pad=15)
cbar.outline.set_edgecolor('black')
cbar.outline.set_linewidth(0.01)
# ----------------------------------------
# --------- custom axis tick labels ----------
# Evenly spaced ticks from the first cell edge (0) to the last, derived from
# the shape of the matrix actually drawn.
x_ticks = np.linspace(0, corr.shape[1], len(bands_wavelength))
y_ticks = np.linspace(0, corr.shape[0], len(bands_wavelength))
ax.set_xticks(x_ticks)
ax.set_yticks(y_ticks)
ax.set_xticklabels(bands_wavelength)
ax.set_yticklabels(bands_wavelength[::-1])
# --------------------------------------------
# plt.title("平方数", fontsize=24)  # set the title
# plt.xlabel("值", fontsize=14)  # set the x-axis label
# plt.ylabel("值的平方", fontsize=14)  # set the y-axis label
# plt.savefig('./热力图.png',bbox_inches='tight',dpi=300)
plt.show()
5、显示对角线的值
# -*- coding:utf-8 -*-
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# 7 variables x 12 samples -> a 7x7 correlation matrix.
data = np.random.random((7, 12))
corr = np.corrcoef(data)
# One tick label per variable (row/column of corr).
bands_wavelength = ["400", "500", "600", "700", "800", "900", "1000"]
# Mark the lower triangle, including the diagonal.
mask = np.zeros_like(corr, dtype=np.bool_)
mask[np.tril_indices_from(mask)] = True
cmap = sns.diverging_palette(220, 10, as_cmap=True)
corr = np.flip(corr, axis=0)
# Flipping the mask along axis=1 (not axis=0) and inverting it with ~ gives a
# symmetric mask. It hides only the cells strictly above the anti-diagonal, so
# the anti-diagonal stays visible. After the row flip of corr, that
# anti-diagonal holds the original diagonal values.
mask = np.flip(mask, axis=1)
ax = sns.heatmap(corr, mask=~mask.T, cmap=cmap, square=True, linewidths=0.5,
                 vmin=np.min(corr), vmax=np.max(corr),  # explicit colour range
                 center=0,  # neutral palette colour at zero correlation
                 cbar=True,
                 xticklabels=bands_wavelength,
                 yticklabels=bands_wavelength[::-1])
# plt.title("平方数", fontsize=24)  # set the title
# plt.xlabel("值", fontsize=14)  # set the x-axis label
# plt.ylabel("值的平方", fontsize=14)  # set the y-axis label
# plt.savefig('./热力图.png',bbox_inches='tight',dpi=300)
plt.show()
6、一些参数
seaborn.heatmap(data, vmin=None, vmax=None, cmap=None, center=None, robust=False, annot=None, fmt='.2g', annot_kws=None, linewidths=0, linecolor='white', cbar=True, cbar_kws=None, cbar_ax=None, square=False, xticklabels='auto', yticklabels='auto', mask=None, ax=None, **kwargs)
- annot
annot是annotate(注释)的缩写,默认为None(不写入注释);当annot为True时,会在heatmap的每个方格中写入对应的数值。
annot_kws:当annot为True时,用于设置注释文字的属性,包括字号、颜色、粗细、斜体等,例如 {'size':6, 'weight':'bold', 'color':'w'}。
import numpy as np
import seaborn as sns
from matplotlib.colors import LogNorm
import matplotlib.ticker as tkr
from matplotlib import pyplot as plt

# Random 10x10 values in [0, 2.5), drawn on a log colour scale anchored at
# [0.5, 2], with every cell annotated with its value (annot=True).
matrix = np.random.rand(10, 10) / 0.4
vmax, vmin = 2, 0.5
cbar_ticks = [0.5, 0.75, 1, 1.33, 2]

# Plain, non-scientific tick labels on the colour bar.
formatter = tkr.ScalarFormatter(useMathText=True)
formatter.set_scientific(False)
log_norm = LogNorm(vmin=vmin, vmax=vmax)

# Size, weight and colour of the annotation text.
annot_style = {'size': 6, 'weight': 'bold', 'color': 'w'}
ax_ = sns.heatmap(
    matrix,
    square=True,
    vmax=vmax,
    vmin=vmin,
    norm=log_norm,
    cbar_kws={"ticks": cbar_ticks, "format": formatter},
    cmap="jet",
    annot=True,
    annot_kws=annot_style,
)
# Remove the colour bar's minor ticks so only cbar_ticks remain.
colorbar = ax_.collections[0].colorbar
colorbar.ax.yaxis.set_ticks([], minor=True)
plt.show()
- fmt
字符串,可选参数,默认为'.2g'。添加注释时使用的字符串格式代码,例如'.2f'表示保留两位小数。
import numpy as np
import seaborn as sns
from matplotlib.colors import LogNorm
import matplotlib.ticker as tkr
from matplotlib import pyplot as plt

# Same annotated log-scale heatmap as above, but `fmt` formats each annotation
# with two decimals ('.2f') instead of the default '.2g'.
matrix = np.random.rand(10, 10) / 0.4
vmax, vmin = 2, 0.5
cbar_ticks = [0.5, 0.75, 1, 1.33, 2]

formatter = tkr.ScalarFormatter(useMathText=True)
formatter.set_scientific(False)
log_norm = LogNorm(vmin=vmin, vmax=vmax)

heatmap_kwargs = dict(
    square=True,
    vmax=vmax,
    vmin=vmin,
    norm=log_norm,
    cbar_kws={"ticks": cbar_ticks, "format": formatter},
    cmap="jet",
    annot=True,
    annot_kws={'size': 6, 'weight': 'bold', 'color': 'w'},
    fmt='.2f',
)
ax_ = sns.heatmap(matrix, **heatmap_kwargs)
# Remove the colour bar's minor ticks so only cbar_ticks remain.
ax_.collections[0].colorbar.ax.yaxis.set_ticks([], minor=True)
plt.show()
- ax
matplotlib Axes,可选参数。多子图时使用,用于指定绘制图的坐标轴(即哪个子图),否则使用当前活动的坐标轴。
import numpy as np
import seaborn as sns
from matplotlib.colors import LogNorm
import matplotlib.ticker as tkr
from matplotlib import pyplot as plt

# Draw the same annotated log-scale heatmap into two stacked subplots, choosing
# the target subplot with seaborn's `ax=` argument.
matrix = np.random.rand(10, 10) / 0.4
vmax, vmin = 2, 0.5
cbar_ticks = [0.5, 0.75, 1, 1.33, 2]

formatter = tkr.ScalarFormatter(useMathText=True)
formatter.set_scientific(False)
log_norm = LogNorm(vmin=vmin, vmax=vmax)

# Two rows of subplots; `ax` holds both Axes objects.
fig, ax = plt.subplots(figsize=(8, 8), nrows=2)
for target_ax in (ax[0], ax[1]):
    drawn = sns.heatmap(matrix, ax=target_ax, square=True,
                        vmax=vmax, vmin=vmin, norm=log_norm,
                        cbar_kws={"ticks": cbar_ticks, "format": formatter},
                        cmap="jet", annot=True,
                        annot_kws={'size': 6, 'weight': 'bold', 'color': 'w'})
    # Remove each colour bar's minor ticks so only cbar_ticks remain.
    drawn.collections[0].colorbar.ax.yaxis.set_ticks([], minor=True)
plt.show()
- vmax, vmin:用于锚定颜色映射的上下限;若不指定,则根据数据和其他关键字参数自动推断。它们同时决定颜色条(colorbar)上显示的最大值和最小值。