EEG | Classifying EEG Signals with the EEGNet Neural Network (Complete Source Code Included)

EEGNet + MNE: classifying the MNE Sample dataset

1. Environment Setup

Package                              Version
Python                               3.7
TensorFlow                           2.7.0
mne                                  0.24.1
matplotlib, scikit-learn, numpy      (any recent version)

For an introduction to mne, see my other post: MNE-Python | An Open-Source Toolkit for Physiological Signal Analysis (Part 1)

2. Dataset Overview

2.1 Data Acquisition

The dataset was acquired with a Neuromag Vectorview system at the Athinoula A. Martinos Center for Biomedical Imaging at MGH/HMS/MIT (Massachusetts General Hospital). EEG data from a 60-channel electrode cap were recorded simultaneously with the MEG (magnetoencephalography) data. The structural MRI data were acquired on a Siemens 1.5 T Sonata scanner using an MPRAGE sequence. Dataset download

2.2 Experimental Design

During the experiment, checkerboard patterns were presented in the subject's left and right visual fields, interleaved with tones presented to the left or right ear, with an inter-stimulus interval of 750 ms. In addition, a smiley face occasionally appeared at the center of the visual field, and the subject was instructed to press a button with the right index finger as quickly as possible after it appeared. The mapping between stimuli and responses is shown in the figure.
[Figure: stimulus and trigger-code mapping]
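Only the four audio/visual conditions are classified in this post; the smiley-face and button-press events are excluded. The trigger codes for those four conditions, as passed to mne.Epochs later in the script, and the 0-based class indices derived from them:

```python
# Trigger codes for the four stimulus conditions in the Sample dataset,
# matching the event_id dict used by mne.Epochs later in this post.
event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)

# One-hot class indices are obtained by subtracting 1 from the trigger code,
# so aud_l -> 0, aud_r -> 1, vis_l -> 2, vis_r -> 3.
class_index = {name: code - 1 for name, code in event_id.items()}
print(class_index)  # {'aud_l': 0, 'aud_r': 1, 'vis_l': 2, 'vis_r': 3}
```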

2.3 Dataset Directory

The Sample dataset has two main parts: MEG/sample (the MEG/EEG recordings) and the subject's MRI reconstruction under subjects/sample. This post mainly uses MEG/sample; its directory layout is shown below.
[Figure: MEG/sample directory layout]

3. Network Model

3.1 Network Architecture

[Figure: EEGNet architecture]

3.2 Code Implementation
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Permute, Dropout
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from tensorflow.keras.layers import SeparableConv2D, DepthwiseConv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import SpatialDropout2D
from tensorflow.keras.regularizers import l1_l2
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.constraints import max_norm
from tensorflow.keras import backend as K


def EEGNet(nb_classes, Chans = 64, Samples = 128, 
             dropoutRate = 0.5, kernLength = 64, F1 = 8, 
             D = 2, F2 = 16, norm_rate = 0.25, dropoutType = 'Dropout'):
    
    if dropoutType == 'SpatialDropout2D':
        dropoutType = SpatialDropout2D
    elif dropoutType == 'Dropout':
        dropoutType = Dropout
    else:
        raise ValueError('dropoutType must be one of SpatialDropout2D '
                         'or Dropout, passed as a string.')
                         
    input1   = Input(shape = (Chans, Samples, 1))

    # Block 1: temporal convolution, then a depthwise convolution spanning
    # all channels (the spatial filter), ELU, average pooling, and dropout
    block1       = Conv2D(F1, (1, kernLength), padding = 'same',
                                   input_shape = (Chans, Samples, 1),
                                   use_bias = False)(input1)
    block1       = BatchNormalization()(block1)
    block1       = DepthwiseConv2D((Chans, 1), use_bias = False, 
                                   depth_multiplier = D,
                                   depthwise_constraint = max_norm(1.))(block1)
    block1       = BatchNormalization()(block1)
    block1       = Activation('elu')(block1)
    block1       = AveragePooling2D((1, 4))(block1)
    block1       = dropoutType(dropoutRate)(block1)
    
    # Block 2: separable convolution (depthwise temporal + pointwise mixing)
    block2       = SeparableConv2D(F2, (1, 16),
                                   use_bias = False, padding = 'same')(block1)
    block2       = BatchNormalization()(block2)
    block2       = Activation('elu')(block2)
    block2       = AveragePooling2D((1, 8))(block2)
    block2       = dropoutType(dropoutRate)(block2)
        
    # Classifier: flatten, max-norm-constrained dense layer, softmax
    flatten      = Flatten(name = 'flatten')(block2)
    
    dense        = Dense(nb_classes, name = 'dense', 
                         kernel_constraint = max_norm(norm_rate))(flatten)
    softmax      = Activation('softmax', name = 'softmax')(dense)
    
    return Model(inputs=input1, outputs=softmax)
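With the 60-channel, 151-sample epochs used in Section 4, the size of the dense layer's input can be checked by hand: the depthwise convolution collapses the channel axis to 1, and the two average-pooling layers shrink the time axis by factors of 4 and then 8, with floor rounding. A quick sketch of that bookkeeping:

```python
# Feature-map bookkeeping for EEGNet with Chans=60, Samples=151, F2=16.
samples = 151
after_pool1 = samples // 4            # AveragePooling2D((1, 4)) -> 37
after_pool2 = after_pool1 // 8        # AveragePooling2D((1, 8)) -> 4
flatten_size = 16 * 1 * after_pool2   # F2 maps x 1 spatial row x 4 time steps
print(after_pool1, after_pool2, flatten_size)  # 37 4 64
```

So the softmax classifier sees only 64 features per trial, which is a large part of why EEGNet stays so compact.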

4. Classification in Practice

4.1 Dataset Preprocessing
def get_data4EEGNet():

    K.set_image_data_format('channels_last')

    # Path to the dataset (replace with your own)
    data_path = '/Users/XXX/XXX/'

    raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
    event_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif'
    
    tmin, tmax = -0., 1
    event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)

    kernels, chans, samples = 1, 60, 151

    raw = io.Raw(raw_fname, preload=True, verbose=False)
    raw.filter(2, None, method='iir')  
    events = mne.read_events(event_fname)

    raw.info['bads'] = ['MEG 2443']  
    picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                           exclude='bads')

    epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=False,
                        picks=picks, baseline=None, preload=True, verbose=False)
    labels = epochs.events[:, -1]

    X = epochs.get_data()*1000    # shape (288, 60, 151): (n_trials, n_channels, n_samples)
    y = labels                    # shape (288,): trigger codes 1-4

    # Split into training, validation, and test sets
    X_train      = X[0:144,]
    Y_train      = y[0:144]
    X_validate   = X[144:216,]
    Y_validate   = y[144:216]
    X_test       = X[216:,]
    Y_test       = y[216:]

    Y_train      = np_utils.to_categorical(Y_train-1)
    Y_validate   = np_utils.to_categorical(Y_validate-1)
    Y_test       = np_utils.to_categorical(Y_test-1)

    X_train      = X_train.reshape(X_train.shape[0], chans, samples, kernels)
    X_validate   = X_validate.reshape(X_validate.shape[0], chans, samples, kernels)
    X_test       = X_test.reshape(X_test.shape[0], chans, samples, kernels)

    return X_train, X_validate, X_test, Y_train, Y_validate, Y_test
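The trigger codes returned by `epochs.events[:, -1]` run from 1 to 4, while `to_categorical` expects 0-based class indices, hence the `Y - 1` shift above. A plain-NumPy equivalent of that encoding (using `np.eye` in place of Keras's `to_categorical`, just to make the mapping explicit):

```python
import numpy as np

# Trigger codes 1..4, as produced by epochs.events[:, -1].
labels = np.array([1, 3, 4, 2])

# Shift to 0-based indices, then index into an identity matrix to one-hot encode.
one_hot = np.eye(4)[labels - 1]
print(one_hot)
# [[1. 0. 0. 0.]
#  [0. 0. 1. 0.]
#  [0. 0. 0. 1.]
#  [0. 1. 0. 0.]]
```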
4.2 Complete Code
import numpy as np

import mne
from mne import io
from mne.datasets import sample

from tensorflow.keras import utils as np_utils
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Permute, Dropout
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from tensorflow.keras.layers import SeparableConv2D, DepthwiseConv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import SpatialDropout2D
from tensorflow.keras.regularizers import l1_l2
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.constraints import max_norm

from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression

from matplotlib import pyplot as plt

import pathlib

def EEGNet(nb_classes, Chans = 64, Samples = 128, 
             dropoutRate = 0.5, kernLength = 64, F1 = 8, 
             D = 2, F2 = 16, norm_rate = 0.25, dropoutType = 'Dropout'):

    
    if dropoutType == 'SpatialDropout2D':
        dropoutType = SpatialDropout2D
    elif dropoutType == 'Dropout':
        dropoutType = Dropout
    else:
        raise ValueError('dropoutType must be one of SpatialDropout2D '
                         'or Dropout, passed as a string.')
    
    input1   = Input(shape = (Chans, Samples, 1))

    block1       = Conv2D(F1, (1, kernLength), padding = 'same',
                                   input_shape = (Chans, Samples, 1),
                                   use_bias = False)(input1)
    block1       = BatchNormalization()(block1)
    block1       = DepthwiseConv2D((Chans, 1), use_bias = False, 
                                   depth_multiplier = D,
                                   depthwise_constraint = max_norm(1.))(block1)
    block1       = BatchNormalization()(block1)
    block1       = Activation('elu')(block1)
    block1       = AveragePooling2D((1, 4))(block1)
    block1       = dropoutType(dropoutRate)(block1)
    
    block2       = SeparableConv2D(F2, (1, 16),
                                   use_bias = False, padding = 'same')(block1)
    block2       = BatchNormalization()(block2)
    block2       = Activation('elu')(block2)
    block2       = AveragePooling2D((1, 8))(block2)
    block2       = dropoutType(dropoutRate)(block2)
        
    flatten      = Flatten(name = 'flatten')(block2)
    
    dense        = Dense(nb_classes, name = 'dense', 
                         kernel_constraint = max_norm(norm_rate))(flatten)
    softmax      = Activation('softmax', name = 'softmax')(dense)
    
    return Model(inputs=input1, outputs=softmax)

def get_data4EEGNet(kernels, chans, samples):

    K.set_image_data_format('channels_last')

    data_path = '/Users/XXX'

    raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
    event_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif'
    
    tmin, tmax = -0., 1
    event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4)

    raw = io.Raw(raw_fname, preload=True, verbose=False)
    raw.filter(2, None, method='iir')  
    events = mne.read_events(event_fname)

    raw.info['bads'] = ['MEG 2443']  
    picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                           exclude='bads')

    epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=False,
                        picks=picks, baseline=None, preload=True, verbose=False)
    labels = epochs.events[:, -1]


    X = epochs.get_data()*1000 
    y = labels

    X_train      = X[0:144,]
    Y_train      = y[0:144]
    X_validate   = X[144:216,]
    Y_validate   = y[144:216]
    X_test       = X[216:,]
    Y_test       = y[216:]


    Y_train      = np_utils.to_categorical(Y_train-1)
    Y_validate   = np_utils.to_categorical(Y_validate-1)
    Y_test       = np_utils.to_categorical(Y_test-1)


    X_train      = X_train.reshape(X_train.shape[0], chans, samples, kernels)
    X_validate   = X_validate.reshape(X_validate.shape[0], chans, samples, kernels)
    X_test       = X_test.reshape(X_test.shape[0], chans, samples, kernels)

    return X_train, X_validate, X_test, Y_train, Y_validate, Y_test

kernels, chans, samples = 1, 60, 151

X_train, X_validate, X_test, Y_train, Y_validate, Y_test = get_data4EEGNet(kernels, chans, samples)

model = EEGNet(nb_classes = 4, Chans = chans, Samples = samples, 
               dropoutRate = 0.5, kernLength = 32, F1 = 8, D = 2, F2 = 16, 
               dropoutType = 'Dropout')

model.compile(loss='categorical_crossentropy', optimizer='adam', 
              metrics = ['accuracy'])

checkpointer = ModelCheckpoint(filepath='/Users/XXX/baseline.h5', verbose=1,
                                save_best_only=True)

class_weights = {0:1, 1:1, 2:1, 3:1}

fittedModel = model.fit(X_train, Y_train, batch_size = 16, epochs = 300, 
                        verbose = 2, validation_data=(X_validate, Y_validate),
                        callbacks=[checkpointer], class_weight = class_weights)

model.load_weights('/Users/XXX/baseline.h5')  # reload the best checkpoint saved above

probs       = model.predict(X_test)
preds       = probs.argmax(axis = -1)  
acc         = np.mean(preds == Y_test.argmax(axis=-1))
print("Classification accuracy: %f " % (acc))
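The final accuracy is just an argmax comparison between predicted probabilities and one-hot labels; a self-contained toy sketch of the same computation, with made-up numbers:

```python
import numpy as np

# Toy predicted probabilities (3 trials, 4 classes) and one-hot ground truth,
# mirroring the argmax-based accuracy computation at the end of the script.
probs = np.array([[0.7, 0.1, 0.1, 0.1],
                  [0.2, 0.5, 0.2, 0.1],
                  [0.1, 0.1, 0.2, 0.6]])
Y_test = np.array([[1, 0, 0, 0],
                   [0, 0, 1, 0],
                   [0, 0, 0, 1]])

preds = probs.argmax(axis=-1)                 # predicted class per trial: [0, 1, 3]
acc = np.mean(preds == Y_test.argmax(axis=-1))  # 2 of 3 correct
print(acc)
```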
4.3 Results

[Figure: training run output]

5. References

Paper: EEGNet: a compact convolutional neural network for EEG-based brain-computer interfaces (Journal of Neural Engineering, SCI JCR Q2, impact factor 4.141)
GitHub: the Army Research Laboratory (ARL) EEGModels project
