EEGNet 处理SEED数据集

需要的库
import mne
import os
import glob
import numpy as np
import torch
import scipy.io as sio

 

定义父类
class LoadData:
    def __init__(self,eeg_file_path: str):
        self.eeg_file_path = eeg_file_path

    def load_raw_data_gdf(self,file_to_load):
        self.raw_eeg_subject = mne.io.read_raw_gdf(self.eeg_file_path + '/' + file_to_load)
        return self

    def load_raw_data_mat(self,file_to_load):
        self.raw_eeg_subject = sio.loadmat(self.eeg_file_path + '/' + file_to_load)

    def get_all_files(self,file_path_extension: str =''):
        if file_path_extension:
            return glob.glob(self.eeg_file_path+'/'+file_path_extension)
        return os.listdir(self.eeg_file_path)

定义对应SEED数据集的子类

class LoadSeed(LoadData):
    '''Subclass of LoadData for loading Seed Dataset '''
    def __init__(self,eeg_file_path, *args):
        self.directory_to_load = eeg_file_path
        self.fs = 200
        self.ch_names = ['FP1', 'FPZ', 'FP2', 'AF3', 'AF4', 'F7', 'F5', 'F3', 'F1',
            'FZ', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FC5', 'FC3', 'FC1',
            'FCZ', 'FC2', 'FC4', 'FC6', 'FT8', 'T7', 'C5', 'C3', 'C1',
            'CZ', 'C2', 'C4', 'C6', 'T8', 'TP7', 'CP5', 'CP3', 'CP1',
            'CPZ', 'CP2', 'CP4', 'CP6', 'TP8', 'P7', 'P5', 'P3', 'P1',
            'PZ', 'P2', 'P4', 'P6', 'P8', 'PO7', 'PO5', 'PO3', 'POZ',
            'PO4', 'PO6', 'PO8', 'CB1', 'O1', 'OZ', 'O2', 'CB2']
        self.basic_label = [1, 0, -1, -1, 0, 1, -1, 0, 1, 1, 0, -1, 0, 1, -1]
        super(LoadSeed,self).__init__(*args)
定义读取SEED单个mat文件的方法
def get_epoch(self, tmin=5, tmax=180, baseline=None, downsampled=None):
        self.load_raw_data_mat(self.file_to_load)
        raw_data = self.raw_eeg_subject
        if downsampled is not None:
            raw_data.resample(sfreq=downsampled)
        raw_data = read_one_file(raw_data, self.ch_names)
        events = np.column_stack((np.arange(0,15*35001,35001),
                          np.zeros(15,dtype=int),
                          self.basic_label))
        stims = dict(negtive=-1, neutral=0, positive=1)
        info = mne.create_info(ch_names=self.ch_names, sfreq=self.fs, ch_types='eeg')
        epochs = mne.EpochsArray(raw_data, info, events, event_id=stims)
        y_labels = epochs.events[:, -1] - min(epochs.events[:, -1])
        x_data = epochs.get_data()*1e6
        eeg_data={'x_data':x_data[:,:,:-1],
                  'y_labels':y_labels,
                  'fs':self.fs}
        return eeg_data
定义读取SEED多个mat文件的方法(不跨被试)
# 读入多个文件
    def get_epochs(self, tmin=5, tmax=180, baseline=None, downsampled=None):
        # Get a list of files starting with "1_" in the specified directory
        files_to_load = glob.glob(os.path.join(self.directory_to_load, '4_*.mat'))
        extended_labels = np.concatenate([self.basic_label] * 3)
        # Initialize an empty array to store data from all files
        all_raw_data = []

        # Loop through each file and load data
        for file_to_load in files_to_load:
            self.raw_eeg_subject = sio.loadmat(file_to_load)
            raw_data = self.raw_eeg_subject

            if downsampled is not None:
                raw_data.resample(sfreq=downsampled)

            raw_data = read_one_file(raw_data, self.ch_names)
            all_raw_data.append(raw_data)

        # Concatenate data from all files
        raw_data = np.concatenate(all_raw_data, axis=0)

        # Rest of the method remains unchanged...
        if downsampled is not None:
            raw_data.resample(sfreq=downsampled)
        events = np.column_stack((np.arange(0,45*35001,35001),
                          np.zeros(45,dtype=int),
                          extended_labels))
        stims = dict(negtive=-1, neutral=0, positive=1)
        info = mne.create_info(ch_names=self.ch_names, sfreq=self.fs, ch_types='eeg')
        epochs = mne.EpochsArray(raw_data, info, events, event_id=stims)
        y_labels = epochs.events[:, -1] - min(epochs.events[:, -1])
        x_data = epochs.get_data()*1e6
        eeg_data={'x_data':x_data[:,:,:-1],
                  'y_labels':y_labels,
                  'fs':self.fs}
        return eeg_data
类中引用的将原始脑电数据切分变形成模型输入格式的方法,有一部分代码冗余可自行修改
def read_one_file(data, ch_names):
    freq=200
    tmin=5
    tmax=180
    # 获取keys并获取数据所在key
    keys = list(data.keys())[3:]
    ## 获取数据
    for i in range(len(keys)):
    # 获取数据
        stamp = data[keys[i]]
        # print(stamp.shape)
        # 创建info
        info = mne.create_info(ch_names=ch_names, sfreq=freq, ch_types='eeg')
        # 创建raw,取第5秒开始的数据
        raw = mne.io.RawArray(stamp, info).crop(tmin, tmax)
        raw = np.array(raw.get_data())
        raw = raw[np.newaxis, :, :]
        if i == 0:
            raw_eeg = raw
        else:
            raw_eeg = np.concatenate((raw_eeg, raw), axis=0)
    return raw_eeg
使用示例
'''for seed Datasets'''
data_path = "PATH/TO/YOUR/FOLD"
seed_data = loaddata.LoadSeed(data_path)
eeg_data = seed_data.get_epochs(tmin=5, tmax=150, baseline=None, downsampled=None)

### 使用 MATLAB 处理和分析 SEED 数据集 SEED 数据集是一个公开的情感脑电图 (EEG) 数据库,包含了来自多个受试者在观看不同情感视频片段时记录的 EEG 数据。以下是关于如何使用 MATLAB 对该数据集进行处理和分析的具体方法。 #### 加载并预览数据 首先需要加载 .mat 文件中的原始 EEG 数据以及对应的标签信息。假设文件名为 `data.mat` 和 `label.mat`: ```matlab % Load the data and label files. load('data.mat'); % Contains raw EEG signals from multiple trials. load('label.mat'); % Contains labels indicating emotional states. disp(size(data)); % Display dimensions of loaded dataset to understand its structure better. disp(unique(label)); % Show unique values present within 'labels' array. ``` #### 预处理阶段 对信号执行必要的清理操作,比如去除噪声、滤波和平滑化等步骤可以提高后续特征提取的效果。这里采用带通滤波器来保留特定频率范围内的成分: ```matlab fs = 200; % Sampling frequency is set at 200 Hz according to documentation. lowcut = 1; highcut = 45; [b,a] = butter(5,[lowcut/(fs/2), highcut/(fs/2)],'bandpass'); filteredData = filtfilt(b, a, data); ``` #### 特征工程 计算统计量作为输入给机器学习模型训练的基础特性。例如均值(mean),标准差(stddeviation),偏度(skewness)等等: ```matlab features.meanValue = mean(filteredData, 2); features.stdDev = std(filteredData,[],2); features.skewness = skewness(filteredData); featureMatrix = [features.meanValue', features.stdDev', features.skewness']; ``` #### 构建分类器 利用上述得到的特征矩阵与已知类别标签一起训练支持向量机(SVM): ```matlab SVMModel = fitcsvm(featureMatrix, label,'KernelFunction','rbf',... 'Standardize',true,... 'ClassNames',[-1 0 1]); predictedLabels = predict(SVMModel, featureMatrix); confusionchart(predictedLabels,label); ``` 以上展示了基本的工作流程,在实际研究过程中可能还需要考虑更多细节方面的问题,如交叉验证评估性能指标的选择等。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值