做毕设,需要用PySide2去显示一个matplotlib图像,但是又得显示到指定标签页,网上基本上都是用canvas,但是又没有指定到某确定位置的方法,这对本菜鸡实在为难。
看了不少Stack Overflow答案都没有对口的,最后还是通过参考《Pyside2中嵌入Matplotlib的绘图》做了出来。
更新,下面这个绘制Precision-Recall曲线的例子对初学者可能有点难,新写了一个文章
一、目的:在PR标签页去绘制一个模型的Precision-Recall的曲线
用到的控件及命名为:
二、定义一个类,继承FigureCanvas
import matplotlib
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
import matplotlib.pyplot as plt
matplotlib.use("Qt5Agg") # 声明使用QT5
class MyFigureCanvas(FigureCanvas):
"""
通过继承FigureCanvas类,使得该类既是一个PyQt5的Qwidget,又是一个matplotlib的FigureCanvas,这是连接pyqt5与matplotlib的关键
"""
def __init__(self, parent=None, width=10, height=5, dpi=100):
# 创建一个Figure
self.fig = plt.Figure(figsize=(width, height), dpi=dpi, tight_layout=True) # tight_layout: 用于去除画图时两边的空白
FigureCanvas.__init__(self, self.fig) # 初始化父类
self.setParent(parent)
self.axes = self.fig.add_subplot(111) # 添加子图
self.axes.spines['top'].set_visible(False) # 去掉绘图时上面的横线
self.axes.spines['right'].set_visible(False) # 去掉绘图时右面的横线
注意:导包时,要将 PySide2 和 UI 的包 放在 matplotlib 相关包 的 前面
三、定义一个共享类,方便保存绘制的图像
class SI:
n_classes = 5
Y_test = None # one-hot后的 Y_pred
Y_pred = None # one-hot后的 Y_test
y_scores = None
# 存储图像
figPR = None
四、窗口,实现button点按绘制显示与保存
下面程序用到Qt Designer设计好的Win_Main窗口中的ui控件都叫self.ui.name,如self.ui.btn_show_PRgraph
import os
from time import localtime, strftime
from sklearn.metrics import precision_recall_curve, roc_curve, roc_auc_score, auc
from sklearn.metrics import average_precision_score
from PySide2.QtWidgets import QApplication, QMessageBox, QFileDialog, QGraphicsScene
import matplotlib
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
import matplotlib.pyplot as plt
matplotlib.use("Qt5Agg") # 声明使用QT5
class Win_Main:
# 1、初始化
def __init__(self):
self.ui = QUiLoader().load("main.ui") # 注意是对ui实例化
... # 此处省略其他标签页的初始化
# 第3个Tab——PR
self.graph_content_PR = MyFigureCanvas(width=self.ui.PRgraph.width() / 101,
height=self.ui.PRgraph.height() / 101
)
self.graphic_scene_PR = QGraphicsScene() # 创建一个QGraphicsScene
self.ui.btn_show_PRgraph.clicked.connect(self.plot_PR) # 生成ROC
self.ui.btn_save_PRgraph.clicked.connect(self.save_PR) # 保存ROC
# 2、绘制PR曲线
def plot_PR(self):
self.graph_content_PR.axes.clear() # 每次绘制时需要清空之前的图像
y_test = SI.Y_test # 这个得你自己去把Y_test到SI类中保存下
y_score = SI.y_scores # 这个得你自己去把y_scores到SI类中保存下
n_classes = SI.n_classes # 这个得你自己去把n_classes到SI类中保存下
# Plot the macro-averaged Precision-Recall curve
# For each class
precision = dict()
recall = dict()
average_precision = dict()
for i in range(n_classes):
precision[i], recall[i], _ = precision_recall_curve(y_test[:, i],
y_score[:, i])
average_precision[i] = average_precision_score(y_test[:, i], y_score[:, i])
# A "macro-average": quantifying score on all classes jointly
precision["macro"], recall["macro"], _ = precision_recall_curve(y_test.ravel(),
y_score.ravel())
average_precision["macro"] = average_precision_score(y_test, y_score,
average="macro")
print('Average precision score, macro-averaged over all classes: {0:0.2f}'
.format(average_precision["macro"]))
# plt.figure()
# self.graph_content_PR.axes.plot(recall['macro'], precision['macro'], where='post') # 可以用plot直接代替(去掉where参数)
self.graph_content_PR.axes.plot(recall['macro'], precision['macro'])
self.graph_content_PR.axes.set_xlabel('Recall')
self.graph_content_PR.axes.set_ylabel('Precision')
self.graph_content_PR.axes.set_ylim([0.0, 1.05])
self.graph_content_PR.axes.set_xlim([0.0, 1.0])
self.graph_content_PR.axes.set_title(
'Average precision score, macro-averaged over all classes: AP={0:0.3f}'
.format(average_precision["macro"]))
SI.figPR = self.graph_content_PR.fig
# 加载的图形(FigureCanvas)不能直接放到graphicview控件中,必须先放到graphicScene,然后再把graphicscene放到graphicview中
# 即 FigureCanvas -> graphicScene -> graphicview
self.graphic_scene_PR.addWidget(self.graph_content_PR) # 把图形放到QGraphicsScene中,注意:图形是作为一个QWidget放到放到QGraphicsScene中的
self.ui.PRgraph.setScene(self.graphic_scene_PR) # 把QGraphicsScene放入QGraphicsView
self.ui.PRgraph.show() # 调用show方法呈现图形
# 3、保存PR曲线
def save_PR(self):
# 给导出名加个时间戳
time_str = strftime("%Y_%m_%d_%H_%M_%S", localtime())
toPath = f"PR_curve_{time_str}.png" # _{SI.trainFileNames[0]}
if os.path.exists(toPath):
pass
else:
SI.figPR.savefig(toPath)
# img.imsave(toPath)
QMessageBox.information(
self.ui,
"保存成功!",
"已经保存到本文件夹"
)
最后调用即可
app = QApplication([])
SI.loginWin = Win_Main()
SI.loginWin.ui.show()
app.exec_()