# Required libraries
import mne
import os
import glob
import numpy as np
import torch
import scipy.io as sio
# Define the parent (base) class: a generic loader for EEG files in one directory
class LoadData:
    """Base loader for EEG recordings stored as files in a single directory.

    Parameters
    ----------
    eeg_file_path : str
        Directory holding the EEG files; file names passed to the
        ``load_*`` methods are resolved relative to it.
    """

    def __init__(self, eeg_file_path: str):
        self.eeg_file_path = eeg_file_path

    def load_raw_data_gdf(self, file_to_load):
        """Read a GDF file with MNE into ``self.raw_eeg_subject`` and return ``self``."""
        self.raw_eeg_subject = mne.io.read_raw_gdf(
            os.path.join(self.eeg_file_path, file_to_load))
        return self

    def load_raw_data_mat(self, file_to_load):
        """Read a MATLAB ``.mat`` file into ``self.raw_eeg_subject`` and return ``self``.

        ``self.raw_eeg_subject`` becomes the dict produced by
        ``scipy.io.loadmat``. Returning ``self`` (the original returned
        ``None``) matches ``load_raw_data_gdf`` so both loaders chain alike.
        """
        self.raw_eeg_subject = sio.loadmat(
            os.path.join(self.eeg_file_path, file_to_load))
        return self

    def get_all_files(self, file_path_extension: str = ''):
        """List the files in ``eeg_file_path``.

        Parameters
        ----------
        file_path_extension : str
            Optional glob pattern such as ``'*.mat'``. When given, return the
            full matching paths from ``glob.glob``; when empty, return the
            bare entry names from ``os.listdir``.
        """
        if file_path_extension:
            return glob.glob(os.path.join(self.eeg_file_path, file_path_extension))
        return os.listdir(self.eeg_file_path)
# Define the subclass for the SEED dataset
class LoadSeed(LoadData):
    """Subclass of LoadData for loading the SEED dataset.

    Each ``.mat`` file holds one session whose 15 trials are stored as
    separate variables; trials are cropped to a common window, stacked and
    wrapped in an ``mne.EpochsArray``.

    Parameters
    ----------
    eeg_file_path : str
        Directory containing the SEED ``.mat`` files.
    *args
        Legacy: when given, forwarded to ``LoadData.__init__`` instead of
        ``eeg_file_path`` (the original behaviour). Normally omitted.
    """

    def __init__(self, eeg_file_path, *args):
        # The original forwarded only *args, so the documented call
        # LoadSeed(path) raised TypeError; fall back to eeg_file_path.
        super().__init__(*(args if args else (eeg_file_path,)))
        # Directory scanned by get_epochs().
        self.directory_to_load = eeg_file_path
        # Sampling rate in Hz assumed for the stored trials.
        self.fs = 200
        self.ch_names = ['FP1', 'FPZ', 'FP2', 'AF3', 'AF4', 'F7', 'F5', 'F3', 'F1',
                         'FZ', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FC5', 'FC3', 'FC1',
                         'FCZ', 'FC2', 'FC4', 'FC6', 'FT8', 'T7', 'C5', 'C3', 'C1',
                         'CZ', 'C2', 'C4', 'C6', 'T8', 'TP7', 'CP5', 'CP3', 'CP1',
                         'CPZ', 'CP2', 'CP4', 'CP6', 'TP8', 'P7', 'P5', 'P3', 'P1',
                         'PZ', 'P2', 'P4', 'P6', 'P8', 'PO7', 'PO5', 'PO3', 'POZ',
                         'PO4', 'PO6', 'PO8', 'CB1', 'O1', 'OZ', 'O2', 'CB2']
        # Label of each of the 15 trials in a session
        # (1 = positive, 0 = neutral, -1 = negative; see event_id below).
        self.basic_label = [1, 0, -1, -1, 0, 1, -1, 0, 1, 1, 0, -1, 0, 1, -1]

    # Read a single SEED .mat file
    def get_epoch(self, tmin=5, tmax=180, baseline=None, downsampled=None,
                  file_to_load=None):
        """Load one SEED ``.mat`` file (one session) as labelled epochs.

        Parameters
        ----------
        tmin, tmax : float
            Crop window in seconds applied to every trial.
        baseline : tuple | None
            Passed to ``mne.EpochsArray``; ``None`` disables baseline correction.
        downsampled : float | None
            If given, resample the epochs to this rate in Hz.
        file_to_load : str | None
            File name relative to ``eeg_file_path``. Falls back to a
            ``file_to_load`` attribute set by the caller (the original API,
            which never set it itself).

        Returns
        -------
        dict
            ``x_data`` (trials, channels, samples), ``y_labels`` (0/1/2 for
            negative/neutral/positive) and ``fs`` (rate of ``x_data``).

        Raises
        ------
        ValueError
            If no file name is available.
        """
        if file_to_load is None:
            file_to_load = getattr(self, 'file_to_load', None)
        if file_to_load is None:
            raise ValueError("get_epoch() needs a file name: pass file_to_load=...")
        self.load_raw_data_mat(file_to_load)
        trials = read_one_file(self.raw_eeg_subject, self.ch_names,
                               tmin=tmin, tmax=tmax, freq=self.fs)
        return self._to_eeg_data(trials, self.basic_label, baseline, downsampled)

    # Read several SEED .mat files of one subject (no cross-subject mixing)
    def get_epochs(self, tmin=5, tmax=180, baseline=None, downsampled=None,
                   file_pattern='4_*.mat'):
        """Load and concatenate all session files matching ``file_pattern``.

        Parameters
        ----------
        tmin, tmax, baseline, downsampled
            As in :meth:`get_epoch`.
        file_pattern : str
            Glob pattern, relative to ``directory_to_load``, selecting the
            files (default: the original hard-coded ``'4_*.mat'``, presumably
            one subject's sessions — confirm against your file naming).

        Returns
        -------
        dict
            Same layout as :meth:`get_epoch`, with trials of all files stacked
            in sorted file-name order.

        Raises
        ------
        FileNotFoundError
            If no file matches ``file_pattern``.
        """
        # Sorted so the trial order is deterministic across platforms.
        files_to_load = sorted(glob.glob(os.path.join(self.directory_to_load, file_pattern)))
        if not files_to_load:
            raise FileNotFoundError(
                f"no files matching {file_pattern!r} in {self.directory_to_load!r}")
        all_trials = []
        for file_to_load in files_to_load:
            self.raw_eeg_subject = sio.loadmat(file_to_load)
            all_trials.append(read_one_file(self.raw_eeg_subject, self.ch_names,
                                            tmin=tmin, tmax=tmax, freq=self.fs))
        trials = np.concatenate(all_trials, axis=0)
        # One copy of the session labels per loaded file (was hard-coded to 3).
        labels = np.tile(self.basic_label, len(files_to_load))
        return self._to_eeg_data(trials, labels, baseline, downsampled)

    def _to_eeg_data(self, trials, labels, baseline, downsampled):
        """Wrap stacked trials in ``mne.EpochsArray`` and build the output dict."""
        labels = np.asarray(labels, dtype=int)
        if len(labels) != trials.shape[0]:
            raise ValueError(
                f"{trials.shape[0]} trials loaded but {len(labels)} labels available")
        # Raw.crop keeps both endpoints (T*fs + 1 samples); drop the last
        # sample as the original did, before any resampling.
        trials = trials[:, :, :-1]
        n_trials, _, n_samples = trials.shape
        # Synthetic events (onset sample, 0, label), one epoch apart.
        events = np.column_stack((np.arange(n_trials) * n_samples,
                                  np.zeros(n_trials, dtype=int),
                                  labels))
        stims = dict(negtive=-1, neutral=0, positive=1)
        info = mne.create_info(ch_names=self.ch_names, sfreq=self.fs, ch_types='eeg')
        epochs = mne.EpochsArray(trials, info, events, event_id=stims, baseline=baseline)
        fs = self.fs
        if downsampled is not None:
            # The original called .resample on a dict/ndarray (AttributeError).
            epochs.resample(sfreq=downsampled)
            fs = epochs.info['sfreq']
        # Shift labels -1/0/1 to 0/1/2.
        y_labels = epochs.events[:, -1] - min(epochs.events[:, -1])
        # Scaled by 1e6 as in the original (presumably V -> uV; TODO confirm units).
        x_data = epochs.get_data() * 1e6
        return {'x_data': x_data,
                'y_labels': y_labels,
                'fs': fs}


# Helper used by the class above: crops the raw EEG trials of one file and
# reshapes them into the model input format (the original note says some of
# this code is redundant and may be modified).
def read_one_file(data, ch_names, tmin=5, tmax=180, freq=200):
    """Crop every trial of one SEED ``.mat`` file and stack them.

    Parameters
    ----------
    data : dict
        Output of ``scipy.io.loadmat``; every key not starting with ``'__'``
        is treated as one trial of shape (n_channels, n_times).
    ch_names : list of str
        Channel names, one per row of each trial.
    tmin, tmax : float
        Crop window in seconds (both endpoints kept by ``Raw.crop``).
    freq : float
        Sampling rate in Hz.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_trials, n_channels, n_samples).

    Raises
    ------
    ValueError
        If ``data`` contains no trial variables.
    """
    # loadmat adds '__header__', '__version__' and '__globals__' entries; the
    # original skipped them positionally with [3:].
    # NOTE(review): trials follow the .mat variable order, which must match
    # basic_label — confirm for your files.
    keys = [key for key in data if not key.startswith('__')]
    if not keys:
        raise ValueError("no trial arrays found in the .mat data")
    info = mne.create_info(ch_names=ch_names, sfreq=freq, ch_types='eeg')
    trials = [mne.io.RawArray(data[key], info).crop(tmin, tmax).get_data()
              for key in keys]
    return np.stack(trials, axis=0)
# Usage example
# Example for the SEED dataset; runs only when this file is executed as a
# script, not on import. LoadSeed is referenced directly because the original
# `loaddata.LoadSeed` used a module name that is undefined here.
if __name__ == "__main__":
    # Placeholder: point this at the directory holding the SEED .mat files.
    data_path = "PATH/TO/YOUR/FOLD"
    seed_data = LoadSeed(data_path)
    eeg_data = seed_data.get_epochs(tmin=5, tmax=150, baseline=None, downsampled=None)