'''
画图,输入多条数据,进行可视化对比。建议10以内。
'''
import plotly.graph_objs as go
import numpy as np
from plotly.subplots import make_subplots
def plot_various_chns(data,names = []):
'''
画图,输入多条数据,进行可视化对比。建议10以内。
@Input:
data: list,要画的数据
name: 每条数据的名字
@Output:
null
'''
fig = go.Figure()
fig = make_subplots(rows=len(data),cols=1,vertical_spacing=0.05,subplot_titles=names)
for i in range(len(data)):
a = range(data[i].shape[0])
a = np.array(a)
fig.append_trace(go.Scatter(
x = a,
y = data[i],
line = dict(shape = 'spline' ),
),row=i+1,col=1)
fig.update_layout(autosize=False,width=980,height=800)
fig.show()
调用
import numpy as np
from plot_eegs_chn import plot_various_chns
# 先读入数据
x1 = np.load('eeg文件')
l1 = np.load('label文件')
# 选出要被展示的数据,同sub,同label
s = []
num = 0 # 要多少条
for i in range(x1.shape[0]): # 遍历挑label一样的数据
if l1[i]==1:
s.append(x1[i][:,3]) # 这里随便选了数据里的第三个通道
num +=1
if num >= 6:
break
# 调用函数
plot_various_chns(s)