在Python中,绘制散点图并添加趋势线(通常是线性回归线)、公式、以及相关系数(Pearson Correlation Coefficient)和均方根误差(RMSE)可以通过结合matplotlib
用于绘图,numpy
用于数学运算,scipy
或statsmodels
用于线性回归计算来实现。不过,对于线性回归线和公式的添加,statsmodels
提供了更直接的方式来获取回归方程的参数。
以下是一个完整的示例,展示如何完成这些步骤:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from sklearn.metrics import mean_squared_error
import statsmodels.api as sm
# 假设你已经有了一个DataFrame,其中包含两列数据
# 这里我们创建一个示例DataFrame
np.random.seed(0)
x = np.random.rand(100) * 100 # 生成0到100之间的随机数
y = 2 * x + 3 + np.random.randn(100) * 10 # 生成y值,与x线性相关但带有噪声
data = pd.DataFrame({'X': x, 'Y': y})
# 计算相关系数
corr, _ = pearsonr(x, y)
# 计算RMSE(这里假设x是预测值,y是实际值,但在这种情况下,我们只是用它们来演示)
# 注意:在真实应用中,你可能会有不同的预测值
rmse = np.sqrt(mean_squared_error(y, x)) # 但在这种情况下,这没有实际意义,只是演示
# 使用statsmodels进行线性回归,获取趋势线参数
X = sm.add_constant(data['X']) # 添加常数项以拟合截距
model = sm.OLS(data['Y'], X).fit()
intercept, slope = model.params[0], model.params[1] # 截距和斜率
# 绘制散点图
plt.figure(figsize=(10, 6))
plt.scatter(data['X'], data['Y'], color='blue', alpha=0.5, label='Data Points')
# 添加趋势线
x_values = np.array(data['X'])
y_pred = intercept + slope * x_values
plt.plot(x_values, y_pred, color='red', label='Trend Line')
# 添加相关系数和RMSE到图中
plt.text(0.02, 0.95, f'Correlation Coefficient: {corr:.2f}', transform=plt.gca().transAxes, fontsize=12, color='green')
plt.text(0.02, 0.90, f'RMSE (for demonstration): {rmse:.2f}', transform=plt.gca().transAxes, fontsize=12, color='red')
# 添加趋势线方程到图中
plt.text(0.85, 0.05, f'Y = {slope:.2f}X + {intercept:.2f}', transform=plt.gca().transAxes, fontsize=12, color='black', ha='right')
# 设置图例、标题和坐标轴标签
plt.legend()
plt.title('Scatter Plot with Trend Line, Correlation, and RMSE')
plt.xlabel('X')
plt.ylabel('Y')
# 显示图形
plt.grid(True)
plt.show()
注意:
-
在这个示例中,我使用了
numpy
来生成一些模拟数据,但在实际应用中,你应该从文件、数据库或其他数据源中加载数据。 -
我计算了RMSE,但在这个上下文中,它并没有实际意义,因为
x
和y
都是实际观测到的数据,而不是预测值与实际值之间的比较。在回归问题中,你通常会有预测值(由模型根据输入数据计算得出)和实际值(观测到的数据),然后计算RMSE来评估模型的性能。 -
我使用了
statsmodels
来执行线性回归,因为它提供了方便的接口来获取回归模型的参数(如截距和斜率),并且可以直接输出回归统计信息。 -
在添加文本到图形时,我使用了
transform=plt.gca().transAxes
来确保文本的位置是相对于整个图形的轴(axes)进行定位的,这样可以避免在图形缩放时文本位置发生变化。