23、EEGNex:处理脑电信号EEG,论文解读+代码实现+BCI IV2a\2b数据结果

EEGNex论文和模型详细的解读:

9、EEGNex:可靠的EEG信号解码模型-CSDN博客

简言之,EEGNex——足够媲美EEGNet的专门用于处理EEG信号的CNN模型,在多个数据集和Moabb数据上实现了SOTA水平!本人对该模型的pytorch实现和在bci iv2a、2b数据上的测试结果如下:

1、代码:

EEGNex_Modile.py:可动态调参配置!

import torch.nn as nn
import torch

class conv(nn.Module):
    def __init__(self, *args) -> None:
        super().__init__()
        if len(args) < 2:
            print('卷积层至少要给出输入与输出的通道数')
            exit()
        else:
            in_channel = args[0]
            out_channel = args[1]
            k = tuple(args[2][0])
            s = args[2][1]
            p = args[2][2]
            b = args[2][3]
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels=in_channel,out_channels=out_channel,kernel_size=k,stride=s,padding=p,bias=b),
                                   nn.BatchNorm2d(out_channel))
    def forward(self,x):
        return self.conv1(x)
    
class dilation_conv1(nn.Module):
    def __init__(self, *args) -> None:
        super().__init__()
        if len(args) < 2:
            print('卷积层至少要给出输入与输出的通道数')
            exit()
        else:
            in_channel = args[0]
            out_channel = args[1]
            k = tuple(args[2][0])
            s = args[2][1]
            p = args[2][2]
            b = args[2][3]
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels=in_channel,out_channels=out_channel,kernel_size=k,stride=s,padding=p,bias=b,dilation=(1,2)),
                                   nn.BatchNorm2d(out_channel))
    def forward(self,x):
        return self.conv1(x)

class dilation_conv2(nn.Module):
    def __init__(self, *args) -> None:
        super().__init__()
        if len(args) < 2:
            print('卷积层至少要给出输入与输出的通道数')
            exit()
        else:
            in_channel = args[0]
            out_channel = args[1]
            k = tuple(args[2][0])
            s = args[2][1]
            p = args[2][2]
            b = args[2][3]
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels=in_channel,out_channels=out_channel,kernel_size=k,stride=s,padding=p,bias=b,dilation=(1,4)),
                                   nn.BatchNorm2d(out_channel))
    def forward(self,x):
        return self.conv1(x)

class DepthConv(nn.Module):
    def __init__(self, *args) -> None:
        super().__init__()
        self.in_channel = args[0]
        self.out_channel = args[1]
        k = tuple(args[2][0])
        s = args[2][1]
        p = args[2][2]
        b = args[2][3]
        self.conv = nn.Sequential(nn.Conv2d(in_channels=self.in_channel,out_channels=self.in_channel,kernel_size=k,stride=s,
                                   padding=p,bias=b,groups=self.in_channel),
                                  nn.BatchNorm2d(self.in_channel))
                                  
                    
    def forward(self,x):
        circle_num = int(self.out_channel / self.in_channel)
        out = []
        for i in range(circle_num):
            out.append(self.conv(x))
        out = torch.concat(tuple(out),dim = 1)
        return out

class SeparableConv(nn.Module):
    def __init__(self, *args) -> None:
        super().__init__()
        if len(args) < 2:
            print('卷积层至少要给出输入与输出的通道数')
            exit()
        else:
            in_channel = args[0]
            out_channel = args[1]
            k = tuple(args[2][0])
            s = args[2][1]
            p = args[2][2]
            b = args[2][3]
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channel,
                      out_channels=in_channel,
                      kernel_size= k,
                      stride=s,
                      groups=in_channel,
                      padding = p,
                      bias=b),
            nn.Conv2d(in_channels=in_channel,
                      out_channels=out_channel,
                      kernel_size=1,
                      stride=s,
                      padding = p,
                      bias=b
                      ),
            nn.BatchNorm2d(out_channel)
        )
    def forward(self,x):
        return self.conv(x)
    
import torch.nn.functional as F

class Activation(nn.Module):
    def __init__(self, *args) -> None:
        super().__init__()
        if args[0] == 'ELU':
            self.act = 'ELU'
    
    def forward(self,x):
        if self.act == 'ELU':
            return F.elu(x)

class Pool(nn.Module):
    def __init__(self, *args) -> None:
        super().__init__()
        if args[0] == 'AVG':
            k = tuple(args[1])
            self.pool = nn.AvgPool2d(kernel_size=k)
    def forward(self,x):
        return self.pool(x)
   
class Batchnorm(nn.Module):
     def __init__(self, *args) -> None:
          super().__init__()
          b = args[0]
          self.bn = nn.BatchNorm2d(b = b)
     def forward(self,x):
          return self.bn(x)
               
class Dropout(nn.Module):
    def __init__(self, *args) -> None:
        super().__init__()
        p = args[0]
        self.drop = nn.Dropout2d(p = p)
    def forward(self,x):
        return self.drop(x)

class FL(nn.Module):
    def __init__(self, *args) -> None:
        super().__init__()
        self.dim = args[0]
    def forward(self,x):
        out = x.flatten(self.dim)
        return out

class FC(nn.Module):
    def __init__(self, *args) -> None:
        super().__init__()
        in_ch = args[0]
        out_ch = args[1]
        b = args[2]
        self.fc = nn.Linear(in_features=in_ch,out_features=out_ch,bias=b)

    def forward(self,x):
        return self.fc(x)

class SoftMax(nn.Module):
    def __init__(self, *args) -> None:
        super().__init__()
        self.dim = args[0]
    def forward(self,x):
        out = F.softmax(x,dim = self.dim)
        return out

def choose_out_ch(param,params):
    if isinstance(param,str):        # 判断输出通道的是什么形式体现的,
        out_channels = params[param]
    elif isinstance(param,list):
        out_channels = 1
        for p in param:
            out_channels *= params[p]
    elif isinstance(param,int):
        out_channels = param
    else:
        out_channels = None
        print('给模型的输出必须是str型,list型,或者int型')
        exit()
    return out_channels

def parse_model(yaml_cfg):
    layer = []
    input_ch = yaml_cfg['params']['ch']
    EEG_ch = yaml_cfg['params']['C']                # EEG的通道数
    class_num = yaml_cfg['params']['num_class']
    F1 = yaml_cfg['params']['F1']
    F2 = yaml_cfg['params']['F2']
    #F3 = yaml_cfg['params']['F3']
    #dilation1 = yaml_cfg['params']['dilation1']
    #dilation2 = yaml_cfg['params']['dilation2']
    
    ch = [input_ch]
    for i, (f, Module_name, args) in enumerate(yaml_cfg['backbon']):
        '''
        f 是上一层的通道数,yaml_cfg
        Mdule_name: 是执行该层的名字
        args 里是类似kernel_size的参数

        '''  
        m = eval(Module_name) if isinstance(Module_name, str) else Module_name

        if m in [FC]:
            a = 1

        if m in [FL]:
            a = 1
        try:
            if m in [conv,DepthConv,SeparableConv,FC,dilation_conv2,dilation_conv1]:
                if f == -1:
                    in_channels = ch[f]
                else:
                    in_channels = f
                out_channels = choose_out_ch(args[0],yaml_cfg['params'])
                param = [in_channels, out_channels, args[1:]]  # args=[in_channels, out_channels, k, s, p]
            elif m in [Activation,Pool,Dropout,FL,SoftMax]:
                param = args
            elif m in [Batchnorm]:
                param = [ch[-1]]
        except:
            a = 1
        
        model_ = m(*param)
        args.clear()
        ch.append(out_channels)
        layer.append(model_)

    return nn.Sequential(*layer)

from copy import deepcopy

class EEGNex(nn.Module):
     def __init__(self) -> None:
          super().__init__()
          cfg = r'EEGNex_config.yaml'
          self.yaml = cfg
          import yaml
          with open(cfg,errors='ignore') as f:
               self.yaml = yaml.safe_load(f)
          self.backbone = parse_model(deepcopy(self.yaml))
          
     def forward(self,x):
          return self.backbone(x)
          
               
     

EEGNex_config.yaml:在这个文件中进行配置参数!

这里fc的256输出量要改一下


params:
  {
    'ch':1,                   # 输入神经网络的feature map的数量
    'C':22,                   # EEG 脑电信号的通道
    'num_class':4,            # 分类的类别
    'F1':8,                   # 
    'F2':32,                  # 
    'D':2                     # EEGNet论文里block1中的D参数
  }



backbon:
  #block1
  [[-1,conv,[F1,[1,128],1,same,False]], #conv包含了BN
   [-1,Activation,[ELU]],
   [-1,conv,[F2,[1,128],1,same,False]],
  #block2
   [-1,DepthConv,[[D,F2],[22,1],1,valid,False]],
   [-1,Activation,[ELU]],
   [-1,Pool,[AVG,[1,4]]],
   [-1,Dropout,[0.25]],
  #block3
   [-1,dilation_conv1,[F2,[1,32],1,same,False]], 
   [-1,dilation_conv2,[F1,[1,32],1,same,False]],
   [-1,Activation,[ELU]],
   [-1,Pool,[AVG,[1,8]]], 
   [-1,Dropout,[0.25]],

   [-1,FL,[1]],
   [256,FC,[num_class,False]],
   [-1,SoftMax,[1]]]

2、2a、2b结果:

2a:

2b:

给个关注吧~后续更新其他模型处理EEG各个数据哦

  • 20
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
电信号(Electroencephalogram,EEG)是记录大活动的一种常用方法。在使用PyTorch处理EEG数据时,可以采取以下步骤: 1. 数据处理:对EEG数据进行预处理以去除噪声和伪迹,例如使用滤波器进行陷波滤波、去除眼电伪迹等。PyTorch提供了各种信号处理工具,如torchvision.transforms等。 2. 特征提取:从EEG信号中提取有用的特征,例如使用时频分析方法(如短时傅里叶变换、小波变换)获取时频域特征。可以使用PyTorch提供的信号处理库(如torch.fft)进行频域分析。 3. 数据标准化:对EEG数据进行标准化处理,使其具有相似的分布和范围。可以使用PyTorch的torch.nn.BatchNorm1d或torchvision.transforms.Normalize进行数据标准化。 4. 构建模型:使用PyTorch构建适合EEG数据处理的模型,例如卷积神经网络(Convolutional Neural Networks,CNN)或循环神经网络(Recurrent Neural Networks,RNN)。可以使用PyTorch的torch.nn模块构建模型,并使用torch.optim模块选择优化器。 5. 模型训练:将预处理后的EEG数据输入模型,并使用PyTorch的torch.nn模块定义损失函数,然后使用优化器进行模型训练。可以使用PyTorch的torch.utils.data.Dataset和torch.utils.data.DataLoader加载和处理EEG数据集。 6. 模型评估:使用预留的测试数据对模型进行评估,计算准确率、精确率、召回率等指标。可以使用PyTorch提供的评估工具,如torchmetrics等。 以上是处理EEG数据的基本步骤,具体的实现方法和流程可以根据具体任务和数据集进行调整和修改。希望对你有所帮助!

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

是馒头阿

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值