# @File: plot_window.py
# @Author: A07567
# @Date: 2021/4/13 10:11
# Description:
import random
from PyQt5 import QtWebEngineWidgets, QtCore, QtGui, QtWidgets
import os
import sys
import tempfile
from PyQt5.QtWebEngineWidgets import QWebEngineDownloadItem
from plotly.io import to_html
import plotly.graph_objs as go
class PlotlyViewer(QtWebEngineWidgets.QWebEngineView):
    """Web view that displays a Plotly figure rendered to a temporary HTML file.

    The figure is serialised with ``plotly.io.to_html`` into a named temporary
    file which the view then loads from disk. The file is removed when the
    widget is closed or, at the latest, when the application is about to quit.
    """

    def __init__(self, parent=None, fig=None):
        """Create the view and show ``fig`` (an empty figure when ``None``).

        Args:
            parent: Optional parent widget.
            fig: Plotly figure to display initially.
        """
        super().__init__(parent=parent)
        self.page().profile().downloadRequested.connect(self.on_downloadRequested)
        self.settings().setAttribute(self.settings().ShowScrollBars, False)
        self.settings().setAttribute(QtWebEngineWidgets.QWebEngineSettings.WebGLEnabled, True)
        # Plotly's HTML declares a UTF-8 charset, so the file must be written
        # as UTF-8 instead of whatever the platform locale encoding is.
        self.temp_file = tempfile.NamedTemporaryFile(
            mode="w", suffix=".html", delete=False, encoding="utf-8"
        )
        # A widget embedded in another window never receives its own close
        # event, so also clean up when the application shuts down.
        app = QtCore.QCoreApplication.instance()
        if app is not None:
            app.aboutToQuit.connect(self._remove_temp_file)
        self.set_figure(fig)

    def set_figure(self, fig=None):
        """Render ``fig`` into the temporary file and (re)load it in the view.

        Spike lines, closest-point hover and a white hover label are applied
        to the figure that is passed in (the figure object itself is modified).

        Args:
            fig: Plotly figure to display; an empty figure is used when ``None``.
        """
        if fig is None:
            fig = go.Figure()
        fig.update_xaxes(showspikes=True)
        fig.update_yaxes(showspikes=True)
        fig.update_layout(
            hovermode='closest',
            hoverlabel=dict(bgcolor="white", font_size=10, font_family="Rockwell"),
        )
        html = to_html(fig, config={"responsive": True, 'scrollZoom': True})
        # Stretch the plot over the whole view and drop the default body margin.
        html += "\n<style>body{margin: 0;}" \
                "\n.plot-container,.main-svg,.svg-container{width:100% !important; height:100% !important;}</style>"
        # Overwrite the previous document in place: rewind, write, cut off any
        # leftover tail of a longer earlier document, then rewind again (which
        # also flushes the buffered text to disk before the page is loaded).
        self.temp_file.seek(0)
        self.temp_file.write(html)
        self.temp_file.truncate()
        self.temp_file.seek(0)
        self.load(QtCore.QUrl.fromLocalFile(self.temp_file.name))

    def _remove_temp_file(self):
        """Close and delete the temporary HTML file; safe to call repeatedly."""
        self.temp_file.close()
        try:
            os.unlink(self.temp_file.name)
        except FileNotFoundError:
            # Already removed by an earlier close or shutdown; nothing to do.
            pass

    def closeEvent(self, event: QtGui.QCloseEvent) -> None:
        """Delete the temporary file before the widget closes."""
        self._remove_temp_file()
        super().closeEvent(event)

    def sizeHint(self) -> QtCore.QSize:
        """Return the preferred size of the view (800x600)."""
        return QtCore.QSize(800, 600)

    def on_downloadRequested(self, download: QWebEngineDownloadItem):
        """Ask the user where to save a requested download and accept it.

        The download is left untouched when the dialog is cancelled.

        Args:
            download: The download item emitted by the web engine profile.
        """
        # getSaveFileName is a static convenience function, so no dialog
        # instance needs to be created or configured beforehand.
        path, _ = QtWidgets.QFileDialog.getSaveFileName(
            self, "Save File", os.path.join(os.getcwd(), "newplot.png"), "*.png"
        )
        if path:
            download.setPath(path)
            download.accept()
class Fig:
    """Build the demo Plotly figure in multi-axis or single-axis mode.

    In multi-axis mode every data series gets its own y-axis drawn to the left
    of the plot; in single-axis mode all series share the default y-axis.

    NOTE(review): the ``multiple`` constructor argument is only stored on the
    instance. As before, a new object always starts in multi-axis mode; use
    ``update_fig`` to select a mode.
    """

    def __init__(self, multiple=True):
        """Set up the sample data and the initial (multi-axis) axis names.

        Args:
            multiple: Stored as ``self.multiple``; see the class note.
        """
        self.multiple = multiple
        self.data = {"speed": [1, 7, 3, 9, 5], "temperature": [110, 80, 130, 170, 200]}
        self.Y_name = []
        self.fig = go.Figure()
        # Axis mode that ``Y_name`` currently describes; ``create_fig`` builds
        # the layout for this mode so traces and layout always agree.
        self._multi_axis = True
        self.__create_yaxis()

    def __create_traces(self):
        """Return one scatter trace per data series, bound to its y-axis name."""
        x = [1, 2, 3, 4, 4]
        return [
            go.Scatter(x=x, y=values, name=str(key), yaxis=self.Y_name[index])
            for index, (key, values) in enumerate(self.data.items())
        ]

    def __create_layout(self):
        """Return the multi-axis layout: one left-hand y-axis per data series.

        The x-axis is restricted to ``[0.2, 1]`` so the additional axes fit in
        the free strip on the left, spaced evenly across it.
        """
        layout = {
            "xaxis": dict(domain=[0.2, 1]),
            "legend": dict(orientation="h", x=0, yanchor="bottom", y=1.12, font=dict(size=8, color="black"))
        }
        keys = list(self.data)
        if not keys:
            # No series: nothing to lay out (and avoids dividing by zero below).
            return layout
        step = round(0.2 / len(keys), 2)
        position = 0
        for index, key in enumerate(keys):
            color = random_color()
            if index == 0:
                # ``title=dict(text=..., font=...)`` replaces the deprecated
                # ``titlefont`` attribute.
                layout["yaxis"] = dict(
                    title=dict(text=str(key), font=dict(size=10)),
                    showline=True, linewidth=2, linecolor=color, ticks="outside",
                )
            else:
                position += step
                # Plotly numbers additional axes from 2: yaxis2, yaxis3, ...
                layout["yaxis{}".format(index + 1)] = dict(
                    title=dict(text=str(key), font=dict(size=10, color=color)),
                    anchor="free", overlaying='y', side="left", position=position,
                    showline=True, linewidth=2, linecolor=color, ticks="outside",
                )
        return layout

    def update_layout(self, multiple=True):
        """Return the layout dict for the requested axis mode.

        Args:
            multiple: ``True`` for one y-axis per series, ``False`` for a
                single shared y-axis with the x-axis spanning the full width.

        Returns:
            A layout dictionary suitable for ``go.Figure(layout=...)``.
        """
        if multiple:
            return self.__create_layout()
        return {
            "xaxis": dict(domain=[0, 1]),
            "legend": dict(orientation="h", x=0, yanchor="bottom", y=1.12, font=dict(size=8, color="black"))
        }

    def __create_yaxis(self):
        """Append the multi-axis names ("y", "y2", "y3", ...) to ``Y_name``."""
        self.Y_name.extend(
            "y" if index == 0 else "y{}".format(index + 1)
            for index in range(len(self.data))
        )

    def update_yaxis(self, multiple=True):
        """Rebuild ``Y_name`` for the requested axis mode.

        Args:
            multiple: ``True`` gives every series its own axis name; ``False``
                maps every series to the default axis ``"y"``.
        """
        self.Y_name.clear()
        if multiple:
            self.__create_yaxis()
        else:
            self.Y_name = ["y"] * len(self.data)
        # Remember the mode so create_fig() pairs these names with the
        # matching layout.
        self._multi_axis = bool(multiple)

    def create_fig(self):
        """Build, store and return the figure for the current axis mode."""
        traces = self.__create_traces()
        layout = self.update_layout(multiple=self._multi_axis)
        self.fig = go.Figure(data=traces, layout=layout)
        return self.fig

    def update_fig(self, multiple=True):
        """Switch the axis mode and return the rebuilt figure.

        Args:
            multiple: ``True`` for multi-axis mode, ``False`` for a single
                shared y-axis.
        """
        self.update_yaxis(multiple=multiple)
        return self.create_fig()
class Window(QtWidgets.QWidget):
    """Demo window: a Plotly view next to a button that toggles the axis mode."""

    def __init__(self):
        """Build the figure, the viewer and the toggle button."""
        super(Window, self).__init__()
        # Single source of truth for the axis mode currently on screen
        # (True = one y-axis per series); the figure builder is created with
        # the same value so the two cannot disagree.
        self.flag = True
        self.f = Fig(multiple=self.flag)
        self.fig = self.f.create_fig()
        self.view = PlotlyViewer(parent=None, fig=self.fig)
        self.btn = QtWidgets.QPushButton(self)
        self.btn.setText("切换轴")
        self.btn.clicked.connect(self.switch)
        layout = QtWidgets.QHBoxLayout()
        layout.addWidget(self.view)
        layout.addWidget(self.btn)
        self.setLayout(layout)
        self.resize(600, 400)

    def switch(self):
        """Toggle between multi-axis and single-axis mode and redraw the plot."""
        self.flag = not self.flag
        self.fig = self.f.update_fig(multiple=self.flag)
        self.view.set_figure(self.fig)
def random_color():
    """Return a random colour as a ``#`` followed by six hex digits.

    Each digit is drawn independently from ``1``-``9`` and ``A``-``F``; the
    digit pool does not contain ``0``.
    """
    digit_pool = "123456789ABCDEF"
    picked = [digit_pool[random.randint(0, 14)] for _ in range(6)]
    return "#" + "".join(picked)
if __name__ == '__main__':
    # Script entry point: create the Qt application, show the demo window
    # and run the event loop until the window is closed.
    application = QtWidgets.QApplication(sys.argv)
    main_window = Window()
    main_window.show()
    application.exec_()