EEGNex论文和模型详细的解读:
简言之,EEGNex 是一种专门用于处理 EEG 信号、性能足以媲美 EEGNet 的 CNN 模型,在多个数据集和 MOABB 基准数据上达到了 SOTA 水平!下面给出本人对该模型的 PyTorch 实现,以及在 BCI Competition IV 2a、2b 数据集上的测试结果:
1、代码:
EEGNex_Module.py:模型结构与参数可通过配置文件动态调整!
import torch.nn as nn
import torch
class conv(nn.Module):
    """2-D convolution followed by batch normalisation.

    Positional args (as produced by ``parse_model``):
        args[0]: input channels.
        args[1]: output channels.
        args[2]: ``[kernel_size, stride, padding, bias]``; ``kernel_size`` is a
            sequence such as ``[1, 128]`` and ``padding`` (int, tuple or a
            string such as ``'same'``/``'valid'``) is forwarded unchanged to
            ``nn.Conv2d``.

    Raises:
        ValueError: if fewer than three positional arguments are given.
    """

    def __init__(self, *args) -> None:
        super().__init__()
        # args[2] is read unconditionally, so all three arguments are
        # required; raise instead of calling exit() so callers can recover.
        if len(args) < 3:
            raise ValueError(
                'conv expects (in_channels, out_channels, '
                f'[kernel_size, stride, padding, bias]); got {len(args)} argument(s)'
            )
        in_channel, out_channel, spec = args[0], args[1], args[2]
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channel,
                out_channels=out_channel,
                kernel_size=tuple(spec[0]),
                stride=spec[1],
                padding=spec[2],
                bias=spec[3],
            ),
            nn.BatchNorm2d(out_channel),
        )

    def forward(self, x):
        """Apply the convolution and batch norm to ``x``."""
        return self.conv1(x)
class dilation_conv1(nn.Module):
    """2-D convolution with dilation ``(1, 2)`` followed by batch normalisation.

    Positional args (as produced by ``parse_model``):
        args[0]: input channels.
        args[1]: output channels.
        args[2]: ``[kernel_size, stride, padding, bias]``; ``kernel_size`` is a
            sequence such as ``[1, 32]`` and ``padding`` (int, tuple or a
            string such as ``'same'``) is forwarded unchanged to ``nn.Conv2d``.

    Raises:
        ValueError: if fewer than three positional arguments are given.
    """

    def __init__(self, *args) -> None:
        super().__init__()
        # args[2] is read unconditionally, so all three arguments are
        # required; raise instead of calling exit() so callers can recover.
        if len(args) < 3:
            raise ValueError(
                'dilation_conv1 expects (in_channels, out_channels, '
                f'[kernel_size, stride, padding, bias]); got {len(args)} argument(s)'
            )
        in_channel, out_channel, spec = args[0], args[1], args[2]
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channel,
                out_channels=out_channel,
                kernel_size=tuple(spec[0]),
                stride=spec[1],
                padding=spec[2],
                bias=spec[3],
                dilation=(1, 2),
            ),
            nn.BatchNorm2d(out_channel),
        )

    def forward(self, x):
        """Apply the dilated convolution and batch norm to ``x``."""
        return self.conv1(x)
class dilation_conv2(nn.Module):
    """2-D convolution with dilation ``(1, 4)`` followed by batch normalisation.

    Positional args (as produced by ``parse_model``):
        args[0]: input channels.
        args[1]: output channels.
        args[2]: ``[kernel_size, stride, padding, bias]``; ``kernel_size`` is a
            sequence such as ``[1, 32]`` and ``padding`` (int, tuple or a
            string such as ``'same'``) is forwarded unchanged to ``nn.Conv2d``.

    Raises:
        ValueError: if fewer than three positional arguments are given.
    """

    def __init__(self, *args) -> None:
        super().__init__()
        # args[2] is read unconditionally, so all three arguments are
        # required; raise instead of calling exit() so callers can recover.
        if len(args) < 3:
            raise ValueError(
                'dilation_conv2 expects (in_channels, out_channels, '
                f'[kernel_size, stride, padding, bias]); got {len(args)} argument(s)'
            )
        in_channel, out_channel, spec = args[0], args[1], args[2]
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channel,
                out_channels=out_channel,
                kernel_size=tuple(spec[0]),
                stride=spec[1],
                padding=spec[2],
                bias=spec[3],
                dilation=(1, 4),
            ),
            nn.BatchNorm2d(out_channel),
        )

    def forward(self, x):
        """Apply the dilated convolution and batch norm to ``x``."""
        return self.conv1(x)
class DepthConv(nn.Module):
    """Depthwise 2-D convolution with a channel multiplier, followed by BN.

    Each input channel is convolved with ``out_channels // in_channels``
    independent kernels (``groups=in_channels``), so ``out_channels`` must be
    a multiple of ``in_channels``; ``nn.Conv2d`` raises ``ValueError``
    otherwise.

    Positional args (as produced by ``parse_model``):
        args[0]: input channels.
        args[1]: output channels (e.g. ``D * F2``).
        args[2]: ``[kernel_size, stride, padding, bias]``; forwarded to
            ``nn.Conv2d`` (``kernel_size`` converted to a tuple).

    Raises:
        ValueError: if fewer than three positional arguments are given.
    """

    def __init__(self, *args) -> None:
        super().__init__()
        if len(args) < 3:
            raise ValueError(
                'DepthConv expects (in_channels, out_channels, '
                f'[kernel_size, stride, padding, bias]); got {len(args)} argument(s)'
            )
        self.in_channel = args[0]
        self.out_channel = args[1]
        spec = args[2]
        # One grouped convolution gives every input channel its own set of
        # distinct kernels. Running a single in->in depthwise conv several
        # times and concatenating would only duplicate identical channels.
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=self.in_channel,
                out_channels=self.out_channel,
                kernel_size=tuple(spec[0]),
                stride=spec[1],
                padding=spec[2],
                bias=spec[3],
                groups=self.in_channel,
            ),
            nn.BatchNorm2d(self.out_channel),
        )

    def forward(self, x):
        """Apply the depthwise convolution and batch norm to ``x``."""
        return self.conv(x)
class SeparableConv(nn.Module):
    """Depthwise-separable 2-D convolution followed by batch normalisation.

    A depthwise conv (``groups=in_channels``) applies ``kernel_size``,
    ``stride`` and ``padding``; a 1x1 pointwise conv then maps the channels to
    ``out_channels``.

    Positional args (as produced by ``parse_model``):
        args[0]: input channels.
        args[1]: output channels.
        args[2]: ``[kernel_size, stride, padding, bias]``.

    Raises:
        ValueError: if fewer than three positional arguments are given.
    """

    def __init__(self, *args) -> None:
        super().__init__()
        if len(args) < 3:
            raise ValueError(
                'SeparableConv expects (in_channels, out_channels, '
                f'[kernel_size, stride, padding, bias]); got {len(args)} argument(s)'
            )
        in_channel, out_channel, spec = args[0], args[1], args[2]
        k, s, p, b = tuple(spec[0]), spec[1], spec[2], spec[3]
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channel,
                out_channels=in_channel,
                kernel_size=k,
                stride=s,
                groups=in_channel,
                padding=p,
                bias=b,
            ),
            # The 1x1 pointwise conv must not stride or pad a second time:
            # the spatial geometry is already set by the depthwise conv above.
            nn.Conv2d(
                in_channels=in_channel,
                out_channels=out_channel,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=b,
            ),
            nn.BatchNorm2d(out_channel),
        )

    def forward(self, x):
        """Apply depthwise conv, pointwise conv and batch norm to ``x``."""
        return self.conv(x)
import torch.nn.functional as F
class Activation(nn.Module):
    """Element-wise activation selected by name (``args[0]``).

    Supported names: ``'ELU'`` (used by the shipped config), ``'ReLU'`` and
    ``'GELU'``.

    Raises:
        ValueError: if the name is not supported (checked at construction so
            a bad config fails immediately instead of in ``forward``).
    """

    _FUNCS = {'ELU': F.elu, 'ReLU': F.relu, 'GELU': F.gelu}

    def __init__(self, *args) -> None:
        super().__init__()
        name = args[0]
        if name not in self._FUNCS:
            raise ValueError(
                f'unsupported activation {name!r}; expected one of {sorted(self._FUNCS)}'
            )
        self.act = name

    def forward(self, x):
        """Apply the selected activation to ``x``."""
        return self._FUNCS[self.act](x)
class Pool(nn.Module):
    """2-D pooling layer configured from positional args.

    args[0]: pooling kind, ``'AVG'`` (``nn.AvgPool2d``) or ``'MAX'``
        (``nn.MaxPool2d``).
    args[1]: kernel size as a sequence such as ``[1, 4]``; the stride defaults
        to the kernel size.

    Raises:
        ValueError: if the pooling kind is not supported.
    """

    _POOLS = {'AVG': nn.AvgPool2d, 'MAX': nn.MaxPool2d}

    def __init__(self, *args) -> None:
        super().__init__()
        kind = args[0]
        if kind not in self._POOLS:
            raise ValueError(
                f'unsupported pooling {kind!r}; expected one of {sorted(self._POOLS)}'
            )
        self.pool = self._POOLS[kind](kernel_size=tuple(args[1]))

    def forward(self, x):
        """Pool ``x`` with the configured kernel."""
        return self.pool(x)
class Batchnorm(nn.Module):
    """``nn.BatchNorm2d`` over ``args[0]`` channels."""

    def __init__(self, *args) -> None:
        super().__init__()
        # nn.BatchNorm2d's channel-count argument is ``num_features``;
        # passing ``b=`` raises TypeError.
        self.bn = nn.BatchNorm2d(num_features=args[0])

    def forward(self, x):
        """Batch-normalise ``x``."""
        return self.bn(x)
class Dropout(nn.Module):
    """Channel-wise dropout wrapper around ``nn.Dropout2d``.

    ``args[0]`` is the drop probability. ``nn.Dropout2d`` zeroes whole
    feature-map channels rather than individual elements.
    """

    def __init__(self, *args) -> None:
        super().__init__()
        self.drop = nn.Dropout2d(args[0])

    def forward(self, x):
        """Apply channel dropout to ``x`` (identity in eval mode)."""
        dropped = self.drop(x)
        return dropped
class FL(nn.Module):
    """Flatten every dimension from ``args[0]`` onwards into one."""

    def __init__(self, *args) -> None:
        super().__init__()
        self.dim = args[0]

    def forward(self, x):
        """Return ``x`` flattened starting at ``self.dim``."""
        return torch.flatten(x, start_dim=self.dim)
class FC(nn.Module):
    """Fully connected layer (``nn.Linear``) configured from positional args.

    args[0]: input features.
    args[1]: output features.
    args[2]: bias flag, either a bool or a sequence whose first element is
        the flag. ``parse_model`` passes the spec tail (e.g. ``[False]``); an
        empty sequence falls back to ``nn.Linear``'s default (``True``).
    """

    def __init__(self, *args) -> None:
        super().__init__()
        in_ch = args[0]
        out_ch = args[1]
        b = args[2]
        # A list such as [False] is truthy, so passing it straight to
        # nn.Linear would silently enable the bias; unwrap it first.
        if isinstance(b, (list, tuple)):
            b = b[0] if b else True
        self.fc = nn.Linear(in_features=in_ch, out_features=out_ch, bias=b)

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        return self.fc(x)
class SoftMax(nn.Module):
    """Softmax over dimension ``args[0]``.

    NOTE(review): if the model is trained with ``nn.CrossEntropyLoss`` (which
    applies log-softmax itself), this layer should probably be removed from
    the config — check the training script.
    """

    def __init__(self, *args) -> None:
        super().__init__()
        self.dim = args[0]

    def forward(self, x):
        """Return softmax probabilities of ``x`` along ``self.dim``."""
        return x.softmax(dim=self.dim)
def choose_out_ch(param, params):
    """Resolve a layer's output-channel spec to a concrete value.

    Args:
        param: one of
            * ``str``  -- a key of ``params`` (e.g. ``'F1'``);
            * ``list`` -- keys of ``params`` whose values are multiplied
              together (e.g. ``['D', 'F2']`` -> ``D * F2``; ``[]`` -> 1);
            * ``int``  -- used as-is.
        params: the ``'params'`` mapping from the YAML config.

    Returns:
        The resolved number of output channels.

    Raises:
        KeyError: if a referenced name is missing from ``params``.
        TypeError: if ``param`` is of any other type (previously the process
            was terminated with ``exit()``).
    """
    if isinstance(param, str):
        return params[param]
    if isinstance(param, list):
        out_channels = 1
        for name in param:
            out_channels *= params[name]
        return out_channels
    if isinstance(param, int):
        return param
    raise TypeError(
        f'给模型的输出必须是str型,list型,或者int型 (got {type(param).__name__})'
    )
def parse_model(yaml_cfg):
    """Build the network described by a parsed EEGNex YAML config.

    Args:
        yaml_cfg: mapping with
            * ``'params'`` -- hyper-parameters. ``'ch'`` (input feature maps
              of the first layer) is required, and every name used as an
              output-channel spec (e.g. ``'F1'``, ``'num_class'``) must exist.
            * ``'backbon'`` -- list of ``[f, module_name, args]`` layer specs
              (the key is spelled exactly like this in the config file).
              ``f`` is the layer's input size: a negative int indexes the
              running list of channel counts (``-1`` = output of the previous
              layer); any other value is used literally (e.g. ``256`` for FC).

    Returns:
        ``nn.Sequential`` containing the layers in config order. ``yaml_cfg``
        itself is not modified.

    Raises:
        KeyError: if a required name is missing from ``params``.
        ValueError: if a layer names an unknown or unsupported module.
    """
    # Explicit whitelist instead of eval(): a config file must not be able to
    # run arbitrary Python, and a typo should fail with a clear message.
    registry = {
        cls.__name__: cls
        for cls in (conv, dilation_conv1, dilation_conv2, DepthConv, SeparableConv,
                    FC, Activation, Pool, Dropout, FL, SoftMax, Batchnorm)
    }
    # Layers constructed as (in_channels, out_channels, spec_tail).
    channel_layers = (conv, DepthConv, SeparableConv, FC, dilation_conv2, dilation_conv1)
    # Layers constructed directly from their args; channel count unchanged.
    passthrough_layers = (Activation, Pool, Dropout, FL, SoftMax)

    params = yaml_cfg['params']
    ch = [params['ch']]  # channel count after each layer; ch[0] is the input
    layers = []
    for index, (f, module_name, args) in enumerate(yaml_cfg['backbon']):
        if isinstance(module_name, str):
            if module_name not in registry:
                raise ValueError(
                    f'layer {index}: unknown module {module_name!r}; '
                    f'expected one of {sorted(registry)}'
                )
            m = registry[module_name]
        else:
            m = module_name

        if m in channel_layers:
            in_channels = ch[f] if isinstance(f, int) and f < 0 else f
            out_channels = choose_out_ch(args[0], params)
            param = [in_channels, out_channels, args[1:]]
        elif m in passthrough_layers:
            param = list(args)
            out_channels = ch[-1]
        elif m is Batchnorm:
            param = [ch[-1]]
            out_channels = ch[-1]
        else:
            raise ValueError(f'layer {index}: unsupported module {m!r}')

        layers.append(m(*param))
        ch.append(out_channels)
    return nn.Sequential(*layers)
from copy import deepcopy
class EEGNex(nn.Module):
    """EEGNex network assembled from a YAML layer specification.

    Args:
        cfg: path to the YAML config (default ``'EEGNex_config.yaml'``,
            resolved relative to the current working directory), or an
            already-loaded config mapping with ``'params'`` and ``'backbon'``.
    """

    def __init__(self, cfg='EEGNex_config.yaml') -> None:
        super().__init__()
        if isinstance(cfg, dict):
            self.yaml = cfg
        else:
            import yaml  # PyYAML; only needed when reading from a file
            # Explicit utf-8 so the result does not depend on the platform's
            # locale encoding (the config contains non-ASCII comments).
            with open(cfg, encoding='utf-8', errors='ignore') as f:
                self.yaml = yaml.safe_load(f)
        # parse_model receives a copy, so the stored config stays untouched.
        self.backbone = parse_model(deepcopy(self.yaml))

    def forward(self, x):
        """Run ``x`` through the configured backbone.

        NOTE(review): with the shipped config ``x`` is presumably shaped
        (batch, ch, C, T) — confirm against the data loader.
        """
        return self.backbone(x)
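EEGNex_config.yaml:在这个文件中配置模型参数!
注意:FC 层前面的 256 是 FC 的输入特征数(不是输出),需要根据输入长度修改。输入形状为 (batch, 1, C, T) 时,它应等于 F1 × ⌊⌊T/4⌋/8⌋(T 为每个 trial 的采样点数),例如 T=1000 时为 8×31=248。同理,DepthConv 卷积核 [22,1] 中的 22 应与 EEG 通道数 C 一致(2b 数据集只有 3 个 EEG 通道)。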
# EEGNex model configuration: loaded with yaml.safe_load by the EEGNex class
# and turned into layers by parse_model.
params:
{
'ch':1, # number of feature maps fed into the network (input channels of the first conv)
'C':22, # number of EEG channels; the layer list below does not refer to C, so keep the DepthConv kernel height (22) equal to it by hand
'num_class':4, # number of classes (out_features of the FC layer)
'F1':8, # out channels of the first conv and of dilation_conv2
'F2':32, # out channels of the second conv and of dilation_conv1
'D':2 # the D parameter of block1 in the EEGNet paper; DepthConv outputs D*F2 channels
}
# Each backbone entry is [f, module, args]:
#   f      -- input size; -1 = channel count produced by the preceding layers,
#             any other value is used literally (e.g. 256 for FC).
#   module -- name of a layer class defined in the model file.
#   args   -- for conv / dilation_conv1 / dilation_conv2 / DepthConv /
#             SeparableConv: [out_channels, kernel_size, stride, padding, bias];
#             for FC: [out_features, bias]. out_channels may be a params key,
#             a list of params keys (their values are multiplied) or an int.
# Bare words such as F1, same, ELU and AVG are loaded as plain strings.
# The key name 'backbon' is spelled this way because parse_model reads that exact key.
backbon:
#block1
[[-1,conv,[F1,[1,128],1,same,False]], # conv = Conv2d followed by BatchNorm2d
[-1,Activation,[ELU]],
[-1,conv,[F2,[1,128],1,same,False]],
#block2
[-1,DepthConv,[[D,F2],[22,1],1,valid,False]], # D*F2 output channels; with 'valid' padding the height-22 kernel reduces dim 2 to 1 when it has size 22
[-1,Activation,[ELU]],
[-1,Pool,[AVG,[1,4]]], # average pooling; last axis shrinks to floor(T/4)
[-1,Dropout,[0.25]], # nn.Dropout2d with p = 0.25
#block3
[-1,dilation_conv1,[F2,[1,32],1,same,False]], # dilation (1,2)
[-1,dilation_conv2,[F1,[1,32],1,same,False]], # dilation (1,4)
[-1,Activation,[ELU]],
[-1,Pool,[AVG,[1,8]]], # average pooling; last axis shrinks by a further factor of 8
[-1,Dropout,[0.25]],
[-1,FL,[1]], # flatten all dims after dim 0
[256,FC,[num_class,False]], # 256 = FC input features (hard-coded). Assuming input (batch, 1, C, T) it must equal F1 * floor(floor(T/4)/8), e.g. 248 for T = 1000 -- TODO adjust per dataset
[-1,SoftMax,[1]]] # softmax over dim 1. NOTE(review): remove if training with nn.CrossEntropyLoss (it applies log-softmax itself) -- check the training script
2、在 BCI IV 2a、2b 数据集上的结果:
2a:
2b:
欢迎关注~后续会继续更新其他模型在各类 EEG 数据上的实现与结果。