EEG-Inception复现记录

一、引言

这是本人花一天时间复现出来的pytorch版本的EEG-Inception,由于pytorch卷积对通道选定的问题,有些部分代码可能和原论文不一样,稍后会说, 如本文有不妥之处,欢迎指出。

请注意:本文章所分享的代码只供参考。

原论文代码请去《EEG-Inception: A Novel Deep Convolutional Neural Network for Assistive ERP-Based Brain-Computer Interfaces》作者github开源代码上自取。

本篇文章不讲解论文,至给出实现代码,EEG-Inception论文讲解推荐博主:

56、巴利亚多利德大学、马德里卡洛斯三世研究所:EEG-Inception-多时间尺度与空间卷积巧妙交叉堆叠,终达SOTA!-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/mantoudamahou/article/details/137679180?spm=1001.2014.3001.5502

本次pytorch代码已上传至本人的github上,需要请自取:

python-implementation-of-motion-imagination-classification/EEG-Inception at main · XCZchaos/python-implementation-of-motion-imagination-classification (github.com)icon-default.png?t=N7T8https://github.com/XCZchaos/python-implementation-of-motion-imagination-classification/tree/main/EEG-Inception

二、模型代码

我们直接上代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary


class DepthwiseSeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super(DepthwiseSeparableConv2d, self).__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, padding='valid', groups=in_channels)
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        # 进行通道压缩


        return x

class EEGInception(nn.Module):
    def __init__(self, input_time=1000, fs=128, ncha=8, filters_per_branch=8,
                 scales_time=(500, 250, 125), dropout_rate=0.25,
                 activation='relu', n_classes=2):
        super(EEGInception, self).__init__()

        # ============================= CALCULATIONS ============================= #
        input_samples = int(input_time * fs / 1000)
        scales_samples = [int(s * fs / 1000) for s in scales_time]

        # ================================ INPUT ================================= #
        self.input_layer = nn.Conv2d(1, ncha, kernel_size=(1, 1))

        # ========================== BLOCK 1: INCEPTION ========================== #
        b1_units = []
        for i in range(len(scales_samples)):
            unit = nn.Sequential(
                nn.Conv2d(ncha, ncha, kernel_size=(1, scales_samples[i]), padding='same'),
                nn.BatchNorm2d(ncha),
                nn.ELU(inplace=True),
                DepthwiseSeparableConv2d(ncha, ncha*2, kernel_size=(ncha, 1)),
                nn.BatchNorm2d(ncha*2),
                nn.ELU(inplace=True),
                nn.Dropout(dropout_rate)
            )
            b1_units.append(unit)

        self.b1_units = nn.ModuleList(b1_units)

        # ========================== BLOCK 2: INCEPTION ========================== #
        b2_units = []
        for i in range(len(scales_samples)):
            unit = nn.Sequential(
                nn.Conv2d(filters_per_branch*6, filters_per_branch, kernel_size=(int(scales_samples[i]/4), 1), padding='same', padding_mode='zeros'),
                nn.BatchNorm2d(filters_per_branch),
                nn.ELU(inplace=True),
                nn.Dropout(dropout_rate)
            )
            b2_units.append(unit)

        self.b2_units = nn.ModuleList(b2_units)

        # ============================ BLOCK 3: OUTPUT =========================== #
        self.b3_u1 = nn.Sequential(
            nn.Conv2d(filters_per_branch * len(scales_samples), int(filters_per_branch*len(scales_samples)/2), kernel_size=(8, 1),padding='same'),
            nn.BatchNorm2d(int(filters_per_branch*len(scales_samples)/2)),
            nn.ELU(inplace=True),
            nn.AvgPool2d((2, 1)),
            nn.Dropout(dropout_rate)
        )

        self.b3_u2 = nn.Sequential(
            nn.Conv2d(int(filters_per_branch*len(scales_samples)/2), int(filters_per_branch*len(scales_samples)/4), kernel_size=(4, 1),padding='same'),
            nn.BatchNorm2d(int(filters_per_branch*len(scales_samples)/4)),
            nn.ELU(inplace=True),
            nn.AvgPool2d((2, 1)),
            nn.Dropout(dropout_rate)
        )

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(int(filters_per_branch*len(scales_samples)/4), n_classes)

    def forward(self, x):
        # ================================ INPUT ================================= #


        x = self.input_layer(x)



        # ========================== BLOCK 1: INCEPTION ========================== #
        b1_outputs = [unit(x) for unit in self.b1_units]


        b1_out = torch.cat(b1_outputs, dim=1)

        b1_out = b1_out.permute((0, 1, 3, 2))

        b1_out = F.avg_pool2d(b1_out, (4, 1))
        # b1_out = b1_out.permute((0, 2, 1, 3))




        # ========================== BLOCK 2: INCEPTION ========================== #
        b2_outputs = [unit(b1_out) for unit in self.b2_units]

        b2_out = torch.cat(b2_outputs, dim=1)

        b2_out = F.avg_pool2d(b2_out, (2, 1))


        # ============================ BLOCK 3: OUTPUT =========================== #
        b3_u1_out = F.avg_pool2d(F.elu(self.b3_u1(b2_out)), (2, 1))

        b3_u2_out = F.avg_pool2d(F.elu(self.b3_u2(b3_u1_out)), (2, 1))

        b3_out = self.avgpool(b3_u2_out)

        b3_out = b3_out.view(b3_out.size(0), -1)
        output = self.fc(b3_out)
        return output





if __name__ == '__main__':
    data = torch.randn(1, 1, 8, 128).to('cuda')
    model = EEGInception().to('cuda')
    output = model(data)
    sum_parameter = 0
    for param in model.parameters():
        sum_parameter += param.numel()
    print(sum_parameter)
    summary(model, (1, 8, 128), device='cuda', batch_size=48)

讲到与原论文不同的是三个block的卷积核都不一样,因为我们输入的shape和原论文不一样,我们的输入的shape为标准的(batch_size, 1, channel, sample),是常用的EEGNet论文的输入,原论文输入请翻阅论文,可将代码与原论文进行对比。

上面是原论文中各层的详细说明,原论文中的参数是15154,本篇文章的代码参数21484,因为block2中的通道数比原论文中较多的原因导致,大家可以参考代码继续修改模型。

三、总结

因时间原因,本模型只用CompetitionIVdataset2a数据集中的A01受试者进行测试,大家可以多用几个受试者对模型进行测试,并进行将近4倍的数据增强,便于模型捕获特征,进行了5-fold cross-validation,训练集与测试集划分为8:2,ave_acc在90%左右。如果本篇文章有任何错误之处,欢迎评论区指出。

在PyCharm中使用PyTorch实现EEGNet网络,你可以按照以下步骤进行: 1. 首先,确保你已经安装了PyTorch库。可以使用以下命令在PyCharm的终端中安装PyTorch: ``` pip install torch torchvision ``` 2. 创建一个新的Python文件,并导入所需的库: ```python import torch import torch.nn as nn import torch.optim as optim ``` 3. 定义EEGNet网络的模型类。EEGNet是一种用于处理脑电图(EEG)信号的轻量级卷积神经网络。以下是一个简单的EEGNet实现示例: ```python class EEGNet(nn.Module): def __init__(self, num_classes): super(EEGNet, self).__init__() self.firstConv = nn.Sequential( nn.Conv2d(1, 16, kernel_size=(1, 51), stride=(1, 1), padding=(0, 25), bias=False), nn.BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), nn.ELU(), nn.MaxPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0), nn.Dropout(p=0.25) ) self.depthwiseConv = nn.Sequential( nn.Conv2d(16, 32, kernel_size=(2, 1), stride=(1, 1), groups=16, bias=False), nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), nn.ELU(), nn.AvgPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0), nn.Dropout(p=0.25) ) self.separableConv = nn.Sequential( nn.Conv2d(32, 32, kernel_size=(1, 15), stride=(1, 1), padding=(0, 7), bias=False), nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), nn.ELU(), nn.AvgPool2d(kernel_size=(1, 8), stride=(1, 8), padding=0), nn.Dropout(p=0.25) ) self.classifier = nn.Linear(736, num_classes) def forward(self, x): x = self.firstConv(x) x = self.depthwiseConv(x) x = self.separableConv(x) x = x.view(x.size(0), -1) x = self.classifier(x) return x ``` 4. 创建一个实例化的EEGNet模型,并定义损失函数和优化器: ```python model = EEGNet(num_classes=2) # 替换num_classes为你的类别数目 criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) ``` 5. 准备你的数据,并进行训练和测试循环: ```python # 假设你的训练数据为train_loader,测试数据为test_loader for epoch in range(num_epochs): model.train() for inputs, labels in train_loader: optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() model.eval() with torch.no_grad(): correct = 0 total = 0 for inputs, labels in test_loader: outputs = model(inputs) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() accuracy = correct / total print(f"Epoch {epoch+1}/{num_epochs}, Test Accuracy: {accuracy}") ``` 这就是在PyCharm中使用PyTorch实现EEGNet网络的基本步骤。你可以根据自己的需求进行修改和扩展。记得提前准备好你的数据集和加载器。祝你成功实现!
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值