python画热力图

python中可使用seaborn.heatmap画热力图,官方文档在这

在分类任务中,也可用于画混淆矩阵:

import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt


def confusion_matrix(y_true, y_pred, labels=None):
    n = len(labels)
    labels_dict = {label: i for i, label in enumerate(labels)}
    res = np.zeros([n, n], dtype=np.int32)
    for gold, predict in zip(y_true, y_pred):
        res[labels_dict[gold]][labels_dict[predict]] += 1

    df = pd.DataFrame(res, index=labels, columns=labels)
    sns.heatmap(df, annot=True, fmt='d')
    plt.savefig("./confusion_matrix.jpg")
    plt.show()

y_true = ["cat", "ant", "cat", "cat", "ant", "bird"]  # 真实
y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"]  # 预测
labels = ["ant", "bird", "cat"]

confusion_matrix(y_true, y_pred, labels)

在这里插入图片描述

一些参数的含义:

def heatmap(
    data, *,
    vmin=None, vmax=None, cmap=None, center=None, robust=False,
    annot=None, fmt=".2g", annot_kws=None,
    linewidths=0, linecolor="white",
    cbar=True, cbar_kws=None, cbar_ax=None,
    square=False, xticklabels="auto", yticklabels="auto",
    mask=None, ax=None,
    **kwargs
)
  • 根据data传入的值画出热力图,一般是二维矩阵
  • vmin设置最小值, vmax设置最大值
  • cmap换用不同的颜色
  • center设置中心值
  • annot 是否在方格上写上对应的数字
  • fmt 写入热力图的数据类型,默认为科学计数,d表示整数,.1f表示保留一位小数
  • linewidths 设置方格之间的间隔
  • xticklabels,yticklabels填到横纵坐标的值。可以是bool,填或者不填。可以是int,以什么间隔填,可以是list

例子:

import numpy as np
np.random.seed(0)
import seaborn as sns
sns.set_theme()
uniform_data = np.random.rand(10, 12)
ax = sns.heatmap(uniform_data)

在这里插入图片描述

将最后一行改为,设置最大值和最小值:

ax = sns.heatmap(uniform_data, vmin=0, vmax=1)

在这里插入图片描述

设置中心值:

normal_data = np.random.randn(10, 12)
ax = sns.heatmap(normal_data, center=0)

在这里插入图片描述

从文件中获取数据,并画图给出有意义的横纵坐标:

flights = sns.load_dataset("flights")
flights = flights.pivot("month", "year", "passengers")
ax = sns.heatmap(flights)

在这里插入图片描述

将passengers对应的人数标出:

ax = sns.heatmap(flights, annot=True, fmt="d")

在这里插入图片描述
设置方格之间的间隔:

ax = sns.heatmap(flights, linewidths=.5)

在这里插入图片描述
设置使用不同的颜色:

ax = sns.heatmap(flights, cmap="YlGnBu")

在这里插入图片描述

以某个具体的数据为中心:

ax = sns.heatmap(flights, center=flights.loc["Jan", 1955])

在这里插入图片描述

自动填充坐标值:

data = np.random.randn(50, 20)
ax = sns.heatmap(data, xticklabels=2, yticklabels=False)

在这里插入图片描述

不画右边的热度条:

ax = sns.heatmap(flights, cbar=False)

在这里插入图片描述

  • 26
    点赞
  • 192
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

旺旺棒棒冰

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值