Seaborn:统计数据可视化最佳实践
在前两篇文章中,我们介绍了Matplotlib的基础功能和复杂数据可视化案例。本文将聚焦于Seaborn库,这是一个建立在Matplotlib基础上的高级统计可视化库,提供了更简洁的API和更美观的默认样式,特别适合统计数据分析和可视化。
1. Seaborn简介
1.1 什么是Seaborn?
Seaborn是一个基于Matplotlib的Python数据可视化库,专为统计绘图而设计。它提供了一套高级接口,使创建复杂且信息丰富的统计图表变得简单。Seaborn的核心优势包括:
- 优雅的默认样式:相比Matplotlib的基础风格,Seaborn提供了更现代、更美观的默认设置
- 内置统计功能:直接集成统计模型,可一步完成数据计算和可视化
- 无缝处理数据框:与pandas数据结构紧密集成,简化数据处理流程
- 多变量关系探索:提供专门的工具来探索和可视化多维数据关系
- 智能调色板:为分类、顺序和连续数据提供合适的配色方案
1.2 Seaborn与Matplotlib的关系
Seaborn建立在Matplotlib的基础上,可以理解为Matplotlib的"高级封装":
- Seaborn使用而非替代Matplotlib
- Seaborn图表本质上是Matplotlib图表,可以通过Matplotlib的API进一步自定义
- Seaborn专注于统计可视化,而Matplotlib支持更广泛的绘图类型
1.3 安装与设置
# 使用pip安装
pip install seaborn
# 或使用conda安装
conda install seaborn -c conda-forge
# 基本导入
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
2. Seaborn的核心绘图类别
Seaborn的绘图函数可以分为6个主要类别,每种类别针对不同的数据分析需求:
2.1 关系型图表(Relational plots)
用于可视化两个变量之间的关系,主要包括:
- 散点图:
scatterplot()
- 线图:
lineplot()
- 综合关系图:
relplot()
# 加载示例数据集
tips = sns.load_dataset("tips")
# 创建散点图
plt.figure(figsize=(10, 6))
sns.scatterplot(x="total_bill", y="tip", hue="time", size="size", data=tips)
plt.title("账单总额与小费关系图")
plt.show()
# 高级关系图
g = sns.relplot(
data=tips,
x="total_bill", y="tip",
col="time", hue="day", size="size",
palette="crest", sizes=(10, 100)
)
g.set_titles("用餐时间: {col_name}")
g.set_axis_labels("账单总额", "小费")
plt.show()
2.2 分布图(Distribution plots)
用于可视化一个或多个变量的分布情况:
- 直方图:
histplot()
- 核密度估计图:
kdeplot()
- 经验累积分布函数:
ecdfplot()
- 综合分布图:
displot()
# 单变量分布
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
sns.histplot(data=tips, x="total_bill", kde=True)
plt.title("账单金额分布图")
plt.subplot(1, 2, 2)
sns.kdeplot(data=tips, x="total_bill", hue="time", fill=True, common_norm=False)
plt.title("不同用餐时间的账单分布")
plt.tight_layout()
plt.show()
# 二元分布
sns.displot(
data=tips,
x="total_bill", y="tip",
kind="kde",
height=7
)
plt.title("账单与小费的二维密度图")
plt.show()
2.3 分类图(Categorical plots)
用于比较不同类别之间的分布:
- 箱线图:
boxplot()
- 小提琴图:
violinplot()
- 散点条形图:
stripplot()
- 蜂群图:
swarmplot()
- 计数图:
countplot()
- 条形图:
barplot()
- 综合分类图:
catplot()
# 箱线图与小提琴图对比
plt.figure(figsize=(14, 6))
plt.subplot(1, 2, 1)
sns.boxplot(x="day", y="total_bill", hue="time", data=tips)
plt.title("箱线图:不同日期的账单分布")
plt.subplot(1, 2, 2)
sns.violinplot(x="day", y="total_bill", hue="time", data=tips, split=True, inner="quart")
plt.title("小提琴图:不同日期的账单分布")
plt.tight_layout()
plt.show()
# 组合分类图
sns.catplot(
data=tips, kind="swarm",
x="day", y="total_bill", hue="sex",
height=6, aspect=1.5
)
plt.title("蜂群图:每天的账单分布与性别差异")
plt.show()
2.4 回归图(Regression plots)
用于可视化变量之间的关系并拟合回归模型:
- 简单回归图:
regplot()
- 分面回归图:
lmplot()
- 残差图:
residplot()
# 简单线性回归
plt.figure(figsize=(10, 6))
sns.regplot(x="total_bill", y="tip", data=tips, scatter_kws={
"alpha": 0.5})
plt.title("账单与小费的线性关系")
plt.show()
# 分组回归
sns.lmplot(
data=tips,
x="total_bill", y="tip",
col="time", row="sex",
height=4
)
plt.suptitle("不同性别和用餐时间的消费-小费关系", y=1.05)
plt.tight_layout()
plt.show()
2.5 矩阵图(Matrix plots)
用于可视化多维数据的矩阵表示:
- 热力图:
heatmap()
- 聚类图:
clustermap()
# 创建相关矩阵
corr = tips.corr()
# 热力图
plt.figure(figsize=(10, 8))
sns.heatmap(
corr,
annot=True, # 显示数值
cmap="coolwarm", # 使用蓝红色映射
vmin=-1, vmax=1, # 设置范围
square=True, # 方形单元格
linewidths=0.5, # 网格线宽度
)
plt.title("变量间相关系数热力图")
plt.tight_layout()
plt.show()
# 聚类热力图
flights = sns.load_dataset("flights")
flights_pivot = flights.pivot("month", "year", "passengers")
plt.figure(figsize=(12, 10))
sns.clustermap(
flights_pivot,
cmap="YlGnBu",
standard_scale=1, # 标准化
method="average", # 聚类方法
figsize=(12, 10)
)
plt.title("航班乘客数聚类热图", y=1.02)
plt.show()
2.6 多面网格图(Facet Grid plots)
用于创建条件关系的网格可视化:
- 因子网格:
FacetGrid()
- 配对网格:
PairGrid()
- 联合网格:
JointGrid()
# FacetGrid示例
g = sns.FacetGrid(tips, col="time", row="smoker", height=4)
g.map_dataframe(sns.scatterplot, x="total_bill", y="tip")
g.add_legend()
plt.tight_layout()
plt.show()
# PairGrid示例
g = sns.PairGrid(tips, hue="time", corner=True)
g.map_lower(sns.scatterplot, alpha=0.7)
g.map_diag(sns.histplot)
g.add_legend()
plt.suptitle("参数之间的配对关系", y=1.02)
plt.show()
# JointGrid示例
g = sns.JointGrid(data=tips, x="total_bill", y="tip", height=7)
g.plot_joint(sns.scatterplot, alp