代码如下：
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from pylab import mpl
import pandas as pd
def scatter(dataset):
    """Show a scatter plot of RF and KRR predictions against the true values.

    Model predictions go on the x-axis and the true values on the y-axis.
    Three reference lines are drawn underneath the points: the identity
    line (solid blue) and the lines ``y = x + 0.2`` and ``y = x - 0.2``
    (dotted yellow).

    Args:
        dataset: Object indexable by the keys ``'y_test'``, ``'yrf'`` and
            ``'ykrr'`` (the caller in this file passes a pandas DataFrame).
            ``'y_test'`` holds the true values; ``'yrf'`` and ``'ykrr'``
            hold the values drawn as the "RF" and "KRR" series.

    Raises:
        KeyError: Propagated from the lookup when one of the three keys is
            missing (for dict- or DataFrame-like inputs).
    """
    # The axis labels are Chinese, so select a font that contains CJK
    # glyphs, and draw minus signs as an ASCII hyphen instead of U+2212.
    mpl.rcParams['font.sans-serif'] = ['STZhongsong']
    mpl.rcParams['axes.unicode_minus'] = False

    # Reference lines: identity, identity + 0.2, identity - 0.2.
    x1 = [0.0, 0.5, 1.0, 1.5, 2.0]
    y1 = [0.0, 0.5, 1.0, 1.5, 2.0]
    x2 = [-0.2, 0.2, 0.7, 1.2, 2.0]
    y2 = [0.0, 0.4, 0.9, 1.4, 2.2]
    x3 = [0.0, 0.5, 1.0, 1.5, 2.0]
    y3 = [-0.2, 0.3, 0.8, 1.3, 1.8]

    y_real = dataset['y_test']
    y_rf = dataset['yrf']
    y_krr = dataset['ykrr']

    plt.figure(dpi=150)
    # Axis limits are [xmin, xmax, ymin, ymax].
    plt.axis([-0.2, 2, 0, 2])
    # NOTE(review): the ticks stop at 1.8 although both axes extend to 2;
    # confirm whether a tick at 2.0 was intended.
    plt.xticks([i * 0.2 for i in range(0, 10)])
    plt.yticks([i * 0.2 for i in range(0, 10)])

    plt.plot(x1, y1, color="b", linestyle="-", linewidth=1)
    plt.plot(x2, y2, color="y", linestyle="dotted", linewidth=1)
    plt.plot(x3, y3, color="y", linestyle="dotted", linewidth=1)

    # scatter(x, y): predictions are the x coordinates, true values the y
    # coordinates. The axis labels below must agree with this order.
    type1 = plt.scatter(y_rf, y_real, c='m', marker='^')
    type2 = plt.scatter(y_krr, y_real, c='g', marker='o')

    plt.rcParams.update({'font.size': 10})
    plt.legend((type1, type2), ("RF", "KRR"), loc='upper left')

    # Fix: the labels used to be swapped (x was labelled "真实值" and y
    # "预测值") even though x carries the predictions and y the true values.
    plt.xlabel("预测值")
    plt.ylabel("真实值")
    plt.show()
if __name__ == '__main__':
    # Read the stored predictions and re-key the CSV columns to the names
    # that scatter() looks up, then draw the figure.
    csv_path = "./model_predictions/example.csv"
    raw = pd.read_csv(csv_path)

    # Key expected by scatter() -> column name in the CSV file.
    column_for_key = {'y_test': 'real', 'yrf': 'rf', 'ykrr': 'krr'}
    plot_data = pd.DataFrame(
        {key: raw[column] for key, column in column_for_key.items()}
    )
    scatter(plot_data)
example.csv 中的样本格式如下图所示：
具体作图效果如下：