【科研分享】Matplotlib 绘制热力图(heatmap)进行实验结果分析

Matplotlib 绘制热力图(heatmap)进行实验结果分析


写论文的时候又碰到了新的需求,为了呈现实验结果,这次需要做一个分析两个超参的图,搜了半天发现还是热力图最合适,但是在各处看了很多篇文章发现讲的要么不是太仔细,要么不太符合我的需求,于是决定投稿完事就赶紧写一篇文章记录一下,本文主要讲了一下如何使用matplotlib绘制热力图,并且给出了一些代码实验结果,分享给大家,希望能有用。

预备工作

前期环境配置

老样子,先导包:pycharm-2021.2.2, python-3.7, matplotlib-3.2.2, numpy-1.20.2 我这里就是用的早期配置的环境,顺便把自己的版本号发上来了,推测如果你用更高版本应该也没事:

import numpy as np
import matplotlib.pyplot as plt

准备数据

比如,我想要展示的是Hit Ratio @ 5这个评价指标,那么我首先需要定义一个np.array,将我想要展示的数据以矩阵的形式存进来,我在此举的例子是绘制一个5X5的热力图,所以我的输入应该是一个5x5的二维array:

HR_5 = np.array([[75.22, 76.34, 75.31, 78.03, 76.57],
                 [80.52, 82.93, 81.33, 83.97, 83.41],
                 [78.70, 80.41, 79.12, 82.91, 81.44],
                 [80.04, 82.66, 81.03, 83.87, 83.28],
                 [78.12, 79.26, 79.21, 80.14, 80.52]])

同样的,我也可以先预定义好我的x轴和y轴的显示样式:

# 坐标轴-y (行标)
tag_y = ["$α$=1", "$α$=2", "$α$=3", "$α$=4", "$α$=5"]
# 坐标轴-x (列标)
tag_x = ["$β$=1", "$β$=2", "$β$=3", "$β$=4", "$β$=5"]

这里解释一下,加$符号的意思表示取斜体,为了和我paper中的图上的斜体符号保持一致,这个写法跟Latex中要求的是一致的,第一次在pycharm里写出来的时候,我也很惊喜。

开始画图

我们就来看一种简单的情况吧,假设我们只需要画一张图(通常为了展示实验结果仅绘制一张图不可能的,但现在就假设我只给他看一种指标(即HR@5 — 一种推荐系统研究中常用的指标)):

fig, ax = plt.subplots()

接下来使用ax变量来对图进行进一步调整:

ax.set_xticks(np.arange(len(tag_x)))
# 设置x轴刻度间隔,参数为x轴刻度长度,其实也可以写作np.arange(0, 5, 1),目的就是提供5个刻度
ax.set_yticks(np.arange(len(tag_y)))
# 设置y轴刻度间隔
ax.set_xticklabels(tag_x)
# 设置x轴标签
ax.set_yticklabels(tag_y)

这里指需要注意刻度的个数必须符合我们想要绘制的热力图矩阵大小,也就是我们输入矩阵的维度。调整好了之后就可以绘制热力图了:

plt.imshow(HR_5, cmap='coolwarm', origin='upper', aspect="auto")

其中,第一个参数要传入的是我们的输入矩阵(在本例子中也就是我们准备的HR_5矩阵),它将代表要显示的格子,cmap参数表示的是你想要使用的热力图颜色风格,我这里使用了一种对比度较高的coolwarm风格,大家可以根据自己的喜好更换这个参数。origin表示的是坐标(0, 0)点的位置,如果选择"upper"则意味着,坐标开始的位置是图像的左上角,如果设置为"lower"则坐标原点在左下角(但因为我们绘制的是实验结果展示图,所以通常从左上角开始显示)。最后的参数aspect如果是‘auto’,则代表图像的长宽比将和坐标系进行自动匹配。

这里给出两个引用,第一个引用方便大家调色:

Matplotlib的imshow()函数颜色映射(cmap的取值)

第二个引用方便大家理解plt.imshow函数中各个参数的意思:

Matplotlib之plt.imshow()方法详解

然后这时候再执行:

plt.show()

就能看见我们的热力图了:
热力图雏形
然后我们为它添加热力标尺:

plt.colorbar()

带标尺的热力图
再为它添加X轴和Y轴的标签(这里额外定义了一个字体格式,为了使坐标轴标签显得不那么小):

default_font = {'family': 'Times New Roman', 'weight': 'bold', 'size': 14}
plt.xlabel('Values of the parameter $β$\n(a) HR@5 (%)\n', default_font)
plt.ylabel('Values of the parameter $α$', default_font)

注意,这几行代码必须放到plt.show()前面才会生效。
问题图片1
这时候我们发现有个问题,下面的坐标轴描述没有显示全,所以我们回到前面去调整图像的大小和边距:

fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=300)
plt.subplots_adjust(top=0.98, bottom=0.15, left=0.15, right=0.99, hspace=0,
                    wspace=0)

这里解释一下,第一行我们在subplots()里加的参数:前两个参数表示我们目前子图就一个(1行1列),第二个参数figsize表示所绘画布的大小5*5, dpi参数越高则图像放大后失真程度越小。
而第二行代码,则是对话不得边距进行调整,top越接近于1则图像整体越靠近顶部,bottom越接近于0,则图像整体越接近于底部,left和right也遵循相似的原则,我们经过上述两行代码可以得到下面这个样子:
调整好的图
这已经非常接近我们想要的结果了,在此基础上还可以给图像增加额外的修饰,如果你想的话。所以,我再次给出一些你可能感兴趣的修改方案:
1.增加标题

plt.subplots_adjust(top=0.94, bottom=0.15, left=0.15, right=0.99, hspace=0,
                    wspace=0)
# 由于增加了标题导致重新调整了一下上边距为0.94
ax.set_title("Fig. 1: Happer-parameter test of $α$ and $β$ on A dataset.")

添加标题
2. 增加文字嵌入

for i in range(len(tag_y)):
    for j in range(len(tag_x)):
        text = ax.text(j, i, HR_5[i, j],
                       ha="center", va="center", color="black", fontweight="bold")

其中,j,i:表示坐标位置上的值,参数ha有三个值可选:right、center和left,分别对应着不同的对齐方式。而va有四个值可以选择:‘top’、 ‘bottom’、 'center’和 ‘baseline’,而参数color则代表设置颜色,fontweight代表字体形式。注意这块代码需要放在plt.imshow()函数之前,才可以得到如下结果:
嵌入
3. 以一定角度翻转x轴刻度

plt.subplots_adjust(top=0.94, bottom=0.22, left=0.15, right=0.99, hspace=0,
                    wspace=0)
# 由于偏转x轴刻度会导致垂直空间上更多地占用,所以重新调整了下边距为0.22
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

注意,这行代码需要放在设置刻度的代码之后,才能得到如下结果:
成品图

完整代码

随后,附上全部源代码:

# @Author: Jinyu Zhang
# @Time: 2022/4/30 15:07
# @E-mail: JinyuZ1996@outlook.com
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

default_font = {'family': 'Times New Roman', 'weight': 'bold', 'size': 14}
tag_y = ["$α$=1", "$α$=2", "$α$=3", "$α$=4", "$α$=5"]
tag_x = ["$β$=1", "$β$=2", "$β$=3", "$β$=4", "$β$=5"]

HR_5 = np.array([[75.22, 76.34, 75.31, 78.03, 76.57],
                 [80.52, 82.93, 81.33, 83.97, 83.41],
                 [78.70, 80.41, 79.12, 82.91, 81.44],
                 [80.04, 82.66, 81.03, 83.87, 83.28],
                 [78.12, 79.26, 79.21, 80.14, 80.52]])

fig, ax = plt.subplots(1, 1, figsize=(5, 5), dpi=300)
plt.subplots_adjust(top=0.94, bottom=0.22, left=0.15, right=0.99, hspace=0,
                    wspace=0)

ax.set_xticks(np.arange(0, 5, 1))
ax.set_yticks(np.arange(len(tag_y)))
ax.set_xticklabels(tag_x)
ax.set_yticklabels(tag_y)
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
         rotation_mode="anchor")

for i in range(len(tag_y)):
    for j in range(len(tag_x)):
        text = ax.text(j, i, HR_5[i, j],
                       ha="center", va="center", color="black", fontweight="bold")

plt.imshow(HR_5, cmap='coolwarm', origin='upper', aspect="auto")
plt.colorbar()
plt.xlabel('Values of the parameter $β$\n(a) HR@5 (%)\n', default_font)
plt.ylabel('Values of the parameter $α$', default_font)
ax.set_title("Fig. 1: Happer-parameter test of $α$ and $β$ on A dataset.", default_font)

plt.show()

更多

本文只讲述了简单的绘制一幅热力图的情况,但假设我想绘制多个评价指标的热力图,该怎么办呢?可以参考我先前的一篇文章来设置多子图实验展示,其实原理很简单,就是将subplot里的参数进行修改,然后再分别对每一个子图准备数据,然后画出来就可以了,参考链接如下:

【科研分享】Matplotlib 绘制多子图(subplot)进行实验结果分析

Reference

在学习热力图绘制过程中以及本文的撰写中,都参考了很多前辈的文章,在此给出引用,站在巨人的肩膀上学习:

使用 matplotlib 绘制热力图 by mr_songw

matplotlib绘制热力图 by jin_tmac

matplotlib绘制热力图,并显示数值 by 苏里

  • 23
    点赞
  • 115
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

JinyuZ1996

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值