使用Matplotlib和Seaborn绘制簇状柱状图
Matplotlib提供了丰富的绘图功能和高度自定义的选项,而Seaborn则基于Matplotlib,提供了更高级别的接口,使得绘制美观的图表变得更为简单。
Seaborn风格与配色方案
Seaborn的风格
Seaborn提供了五种主要的风格,分别是:darkgrid
、whitegrid
、dark
、white
和ticks
。
darkgrid
:在暗色背景上绘制网格线,适合用于大多数图表,尤其是那些需要强调数据的情况。whitegrid
:在白色背景上绘制网格线,同样适用于大多数图表,但更适合在浅色背景下使用。dark
:暗色背景且没有网格线,适合创建科学报告或黑暗主题的图表。white
:白色背景且没有网格线,适合在明亮的环境中使用,尤其是在演示或打印时。ticks
:在白色背景上绘制刻度线,适合创建包含刻度标记的图表,特别是那些需要强调刻度的情况。
import seaborn as sns
# 设置 seaborn 风格
sns.set_style("darkgrid")
Seaborn的配色方案
1. 预定义的调色板
除了风格外,Seaborn还提供了多种配色方案,用于在图表中区分不同的数据类别。以下是一些常用的配色方案:
deep
:一组较深的颜色,用于区分不同的类别,特别适用于分类数据。muted
:一组柔和的颜色,不太鲜艳,适用于中等大小的数据集。bright
:一组明亮的颜色,非常鲜艳,适用于需要强调不同类别的情况。pastel
:一组柔和的颜色,比较淡雅,适用于视觉上温和的图表。dark
:一组深色的颜色,适用于需要强调数据的情况。colorblind
:适用于色盲的颜色方案,能够保证即使是色盲患者也能够清晰地区分不同的类别。husl
:一组从黄色到红色再到蓝色的连续颜色,适用于需要区分多个类别且不希望颜色过于刺眼的情况。Paired
:一组成对的颜色,每个颜色都与另一个颜色形成对比,适用于需要将数据按照两两配对进行比较的情况。Set1
、Set2
、Set3
:三组离散的颜色,每组颜色都有不同的亮度和饱和度,适用于区分多个类别的情况。tab10
、tab20
、tab20b
、tab20c
:四组离散的颜色,每组颜色都有不同的亮度和饱和度,适用于区分多个类别的情况,特别是在需要更多的颜色选择时。
# 设置 seaborn 配色方案
sns.set_palette("tab20")
2. 定制调色板
除了预定义调色板外,Seaborn还允许用户创建自己的定制调色板。可以使用sns.color_palette()函数来创建定制调色板,并将其传递给set_palette()函数。
custom_palette = sns.color_palette(["#fe9900", "#ffff00", "#0099cb"]) # 使用RGB颜色值创建调色板
# custom_palette = sns.color_palette(["red", "green", "blue"]) # 使用颜色名称创建调色板
sns.set_palette(custom_palette)
3. 使用颜色映射
在某些情况下,可能希望根据数据的值来自动选择颜色。这时可以使用Seaborn提供的颜色映射功能。例如,在绘制热图或散点图时,可以使用cmap
参数来指定颜色映射。Seaborn支持许多不同的颜色映射,例如"viridis"、“inferno”、"coolwarm"等等。
sns.heatmap(data, cmap="viridis")
在Seaborn中,可以通过设置参数来调整颜色的饱和度、透明度和字体大小等。下面介绍如何使用Seaborn来实现这些设置:
其他设置
1. 颜色饱和度
Seaborn中的set_palette()
函数可以接受一个saturation
参数,用于调整颜色的饱和度。饱和度参数的取值范围为0到1,其中0表示完全不饱和(灰度色),1表示完全饱和(原始颜色)。例如:
sns.set_palette("deep", saturation=0.8)
这将使用"deep"调色板,并将其饱和度设置为0.8。
2. 透明度
在Seaborn中,可以通过在绘图函数中传递alpha
参数来设置图表元素的透明度。透明度参数的取值范围为0到1,其中0表示完全透明,1表示完全不透明。例如:
# 示例数据
data = pd.DataFrame({'x': [1, 2, 3, 4, 5], 'y': [2, 4, 6, 8, 10]})
# 绘制散点图,并设置透明度
sns.barplot(x="x", y="y", data=data, alpha=0.5)
这将绘制一个柱状图,并将其透明度设置为0.5。
3. 字体大小
Seaborn中的字体大小可以通过设置Matplotlib的参数来实现。可以使用rc
方法来设置Matplotlib的参数,从而调整字体大小。例如:
sns.set_context("paper", rc={"font.size": 13})
这将将绘图上下文设置为"paper",并将字体大小设置为13。
Seaborn提供了几种不同的绘图上下文,包括paper
、notebook
、talk
和poster
,每种上下文对应不同的字体大小。
示例代码
import os # 导入用于操作操作系统功能的 os 模块
import pandas as pd # 导入用于数据处理和分析的 Pandas 库
import matplotlib.pyplot as plt # 导入用于绘图的 Matplotlib 库
import numpy as np # 导入用于数值计算的 Numpy 库
import seaborn as sns # 导入用于数据可视化的 Seaborn 库
# 设置 seaborn 风格为 darkgrid
sns.set_style("darkgrid")
# 设置 seaborn 配色方案为 tab20
sns.set_palette("tab20")
# 设置 seaborn 字体大小
sns.set_context(rc={"font.size": 13})
# 定义数据集名称列表
dataset_names = ['ETTm1', 'ETTm2', 'ETTh1', 'AF.PA']
# 定义时间步列表
time_steps = [24, 48, 96, 288, 672]
times = ['24', '48', '96', '288', '672']
etth1_time_steps = [24, 48, 168, 336, 720]
etth1_times = ['24', '48', '168', '336', '720']
# 定义指标列表
metrics = ['MSE', 'MAE']
# 定义模型名称列表
models = ['VATSRL', 'TS2Vec', 'TS-TCC', 'TST', 'TNC', 'CRT']
# 定义颜色列表
colors = []
# 定义绘制柱状图的函数
def plot(dataset):
# 遍历数据集
for dataset_name, datas in dataset.items():
# 遍历指标
for metric in metrics:
# 初始化一个空列表来保存数据,用于绘制柱状图
data = datas
bars_data = []
# 根据数据集名称选择时间步和时间列表
if dataset_name == "ETTh1":
choose = etth1_time_steps
chooses = etth1_times
else:
choose = time_steps
chooses = times
# 遍历时间步
for lookhead in choose:
lookhead_data = []
# 遍历模型
for model in models:
# 如果数据存在,则添加到列表中;否则添加 NaN
if lookhead in data[model][metric]:
lookhead_data.append(data[model][metric][lookhead])
else:
lookhead_data.append(np.nan)
bars_data.append(lookhead_data)
# 创建绘图对象
fig, ax = plt.subplots()
width = 0.85 # 每个柱子的宽度
index = np.arange(len(models)) * len(time_steps) # 每个模型的起始位置
bar_width = width * len(time_steps) # 每个模型的总宽度
# 为每个模型绘制柱状图
for i, data in enumerate(bars_data):
data = np.array(data)
# ax.bar(index + i * width, data, width, label=chooses[i], alpha=0.75)
ax.bar(index + i * width, data, width, alpha=0.75)
# 设置图表的标题和轴标签
ax.set_title(f'{dataset_name} {metric} Scores')
ax.set_ylabel(metric)
ax.set_xticks(index + (bar_width - width) / 2)
ax.set_xticklabels(models, rotation=45, ha='right')
ax.set_xlim(index[0] - bar_width / 2, index[-1] + bar_width * 1.3)
# 添加图例,看需要添加
# ax.legend()
# 显示网格
ax.grid(axis='y', alpha=0.5)
# 保存图表为 PDF 文件
filename = './result/Prediction'
if not os.path.exists(filename):
os.makedirs(filename)
plt.savefig(f'{filename}/{dataset_name}_{metric}.pdf', bbox_inches='tight')
# bbox_inches='tight', 自动调整页面大小以适应图形的大小,以防止图形的一部分超出页面边界
# 显示图表
plt.tight_layout()
plt.show()
# 调用绘制函数
plot(dataset)
由于绘图使用的数据是我们实验过程中得到的实际数据,论文还未发表,暂时不方便提供。
csv表格中的数据经过读取后的格式大概如下,可以自行准备数据。
dataset = {
'ETTm1': {
'VATSRL': {'MSE': {24: value, 48: value, 96: , 288: , 672:}, 'MAE': {24: , 48: , 96: , 288: , 672: }},
'TS2Vec': {...},
'TS-TCC': {...},
'TST': {...},
'TNC': {...},
'CRT': {...}
},
'ETTm2': {...},
'ETTh1': {'VATSRL': {'MSE': {24: value, 48: value, 168: , 336: , 720: }, ...},
...
}
最后绘制出的图像效果