Matplotlib:复杂数据可视化案例解析
在上一篇文章中,我们介绍了Matplotlib的基础绘图功能和自定义图表的方法。本文将进一步深入,通过实际案例展示如何使用Matplotlib处理和可视化复杂数据集,解决实际数据分析中的挑战性问题。
1. 复杂数据可视化的挑战
处理复杂数据集进行可视化时,通常面临以下挑战:
- 多维数据:需要在二维平面上展示三维或更高维度的数据
- 多变量关系:需要同时展示多个变量之间的复杂关系
- 大数据集:需要有效处理大量数据点而不牺牲可视化质量
- 时间序列:需要展示随时间变化的趋势、季节性和异常
- 空间数据:需要将数据映射到地理坐标上
本文将通过五个案例,展示Matplotlib如何应对这些挑战,创建既美观又富有洞察力的可视化。
2. 案例一:多变量数据分析 - 散点图矩阵
2.1 问题背景
在数据探索阶段,了解多个变量之间的关系至关重要。当变量数量增多时,传统的两两散点图变得繁琐。散点图矩阵(Scatter Plot Matrix)提供了一种高效查看所有变量对之间关系的方法。
2.2 数据准备
我们将使用经典的鸢尾花(Iris)数据集,其包含四个特征变量和一个类别变量:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
# 加载鸢尾花数据集
iris = load_iris()
df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names)
# 查看数据前几行
print(df.head())
2.3 实现散点图矩阵
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
# 加载数据
iris = load_iris()
df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names)
# 设置绘图风格
plt.style.use('seaborn-whitegrid')
# 创建特征变量列表
features = iris.feature_names
# 创建颜色映射
colors = {
'setosa': 'red', 'versicolor': 'green', 'virginica': 'blue'}
species_colors = df['species'].map(colors)
# 创建散点图矩阵
fig, axes = plt.subplots(len(features), len(features), figsize=(15, 15))
# 调整子图间距
plt.subplots_adjust(wspace=0.3, hspace=0.3)
# 填充散点图矩阵
for i, feature_i in enumerate(features):
for j, feature_j in enumerate(features):
ax = axes[i, j]
# 对角线上绘制直方图
if i == j:
# 为每个物种创建单独的直方图
for species_name in iris.target_names:
species_data = df[df['species'] == species_name]
ax.hist(species_data[feature_i], alpha=0.5, bins=20,
color=colors[species_name], label=species_name)
ax.set_title(feature_i.split(' (')[0], fontsize=12)
# 只在第一个对角图上显示图例
if i == 0:
ax.legend(fontsize=8)
# 非对角线位置绘制散点图
else:
for species_name in iris.target_names:
species_data = df[df['species'] == species_name]
ax.scatter(species_data[feature_j], species_data[feature_i],
c=colors[species_name], label=species_name, alpha=0.6, s=30)
# 添加趋势线 (仅对全体数据)
try:
z = np.polyfit(df[feature_j], df[feature_i], 1)
p = np.poly1d(z)
x_min, x_max = df[feature_j].min(), df[feature_j].max()
ax.plot([x_min, x_max], [p(x_min), p(x_max)], "k--", alpha=0.3)
except:
pass
# 设置轴标签 (只在边缘位置显示)
if j == 0:
ax.set_ylabel(feature_i.split(' (')[0], fontsize=10)
if i == len(features) - 1:
ax.set_xlabel(feature_j.split(' (')[0], fontsize=10)
plt.suptitle('鸢尾花数据集散点图矩阵', fontsize=20)
plt.tight_layout(rect=[0, 0, 1, 0.97]) # 为总标题留出空间
plt.savefig('iris_scatter_matrix.png', dpi=300, bbox_inches='tight')
plt.show()
2.4 结果分析
(执行代码后,将生成实际的图像)
从散点图矩阵中,我们可以观察到:
- 对角线:显示每个特征的分布情况,可以看出不同种类鸢尾花在各特征上的分布差异
- 非对角元素:展示两两特征之间的关系,帮助识别相关性和聚类模式
- 趋势线:提供变量间线性关系的直观指示
- 颜色编码:不同颜色代表不同种类,清晰展示分类边界
这种可视化方法的主要优势是能够一次性展示所有特征对之间的关系,有效发现数据中的模式和异常。
2.5 技术要点与提升
- 性能优化:当数据点增多时,可以使用alpha透明度和降采样技术提高可读性
- 增强表现力:可添加等高线来展示点的密度分布
- 交互性:结合plotly或mpld3库可为散点图矩阵添加交互功能
- 扩展性:可以轻松添加更多统计信息,如相关系数或p值
# 可选:在每个子图右上角添加皮尔逊相关系数
if i != j:
corr = df[[feature_i, feature_j]].corr().iloc[0, 1]
ax.annotate(f'r = {
corr:.2f}', xy=(0.05, 0.95), xycoords='axes fraction',
ha='left', va='top', fontsize=9, bbox=dict(boxstyle='round,pad=0.5',
fc='yellow', alpha=0.3))
3. 案例二:时间序列可视化 - 金融数据分析
3.1 问题背景
金融数据分析是时间序列可视化的典型应用场景。投资者需要同时观察价格走势、交易量、技术指标等多个维度的信息。这要求可视化既信息丰富又易于解读。
3.2 数据准备
我们将使用yfinance库获取股票历史数据,并计算常用技术指标:
import yfinance as yf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import FuncFormatter
# 获取苹果公司股票数据(过去2年)
ticker = 'AAPL'
start_date = '2021-01-01'
end_date = '2022-12-31'
stock_data = yf.download(ticker, start=start_date, end=end_date)
# 计算技术指标
# 1. 移动平均线 (20天和50天)
stock_data['MA20'] = stock_data['Close'].rolling(window=20).mean()
stock_data['MA50'] = stock_data['Close'].rolling(window=50).mean()
# 2. 相对强弱指标 (RSI)
def calculate_rsi(data, window=14):
delta = data.diff()
gain = delta.where(delta > 0, 0).rolling(window=window).mean()
loss = -delta.where(delta < 0, 0).rolling(window=window).mean()
rs = gain / loss
return 100 - (100 / (1 + rs))
stock_data['RSI'] = calculate_rsi(stock_data['Close'])
# 3. MACD (移动平均收敛/发散)
def calculate_macd(data, fast=12, slow=26, signal=9):
ema_fast = data.ewm(span=fast, adjust=False).mean()
ema_slow = data.ewm(span=slow, adjust=False).mean()
macd_line = ema_fast - ema_slow
signal_line = macd_line.ewm(span=signal, adjust=False).mean()
histogram = macd_line - signal_line
return macd_line, signal_line, histogram
stock_data['MACD'], stock_data['Signal'], stock_data['Histogram'] = calculate_macd(stock_data['Close'])
# 打印数据头部
print(stock_data.head())
3.3 实现复合金融图表
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import FuncFormatter
import numpy as np
import pandas as pd
# 设置图表风格
plt.style.use('fivethirtyeight')
# 创建一个复合图表
fig = plt.figure(figsize=(16, 12))
gs = GridSpec(4, 1, height_ratios=[3, 1, 1, 1], hspace=0.1)
# 1. 顶部子图:股价和移动平均线
ax1 = fig.add_subplot(gs[0])
ax1.plot(stock_data.index, stock_data['Close'], label='收盘价', linewidth=2, alpha=0.7)
ax1.plot(stock_data.index, stock_data['MA20'], label='20日均线', linewidth=1.5, alpha=0.8)
ax1.plot(stock_data.index, stock_data['MA50'], label='50日均线', linewidth=1.5, alpha=0.8)
# 设置标题和标签
ax1.set_title(f'{
ticker} 股价走势与技术指标 ({
start_date} 至 {
end_date})', fontsize=16)
ax1.set_ylabel('价格 (USD)', fontsize=12)
ax1.legend(loc='upper left')
ax1.grid(True, alpha=0.3)
# 美化x轴日期显示
ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
ax1.xaxis.set_major_locator(mdates.MonthLocator(interval=2))
# 2. 第二个子图:成交量
ax2 = fig.add_subplot(gs[1], sharex=ax1)
ax2.bar(stock_data.index, stock_data['Volume'], label='成交量', alpha=0.7, width=1.0)
ax2.set_ylabel('成交量', fontsize=12)
ax2.yaxis.set_major_formatter(FuncFormatter(