pytorch代码复现:MIEEG-GAN——生成人工运动想象脑电图信号【1】

2024/11/29-2024/12/2:
简单阅读一下论文与代码。想尽量在这几天完成代码的复现工作,不然一直拖着太慢了。
这次文献的阅读主要解决我一直以来的一个疑问:如何判断生成的EEG信号是“好”的?
论文地址MIEEG-GAN: Generating Artificial Motor Imagery Electroencephalography
Signals

论文作者github地址Generative-networks-EEG
博主复现项目gitee地址基于GAN的运动想象EEG信号生成
reference:
[1]BCI competition IV 2b数据集简单处理与批量读取
[2]GAN原理讲解与PyTorch手写逐行讲解

一、实验设置

1.1 数据集相关

该论文采用数据集为BCI competition IV 2b,相关信息可见于reference[1],该数据集是一个二分类的MI数据集,较为简单。
实验设置
在该数据集中,运动想象信号的窗口是3-7s处。数据原本有3个EEG通道,而论文只用了C3这个通道,估计是为了方便处理以及控制训练GAN的时长,但是个人觉得一起生成3个通道相比于只做一个其实没差很多。

在样本划分方面将这共4s的数据再次划分窗口,窗口大小为2s,窗口移动步长为200ms,从每一段4s的数据中得到了11个可用于训练的样本,每个样本长度为2s(在250hz下是500个采样点)。

而这些样本随后又会被用于构建256ms子窗口(64个采样点),step为56ms(换言之相邻窗口重叠50个采样点)的STFT图像,快速傅里叶变换(FFT)点的数量为512。因此频谱图的大小为257×32,其中257是频率分量的数量,32是时间点的数量。从该频谱图中,选择 13-32 Hz 的 β 频谱图,其大小为 41×32。这些图像最后被用于验证生成EEG信号的可行性。

这里比较奇怪的是,在每个运动想象的4s时间中,其中最后1s(6-7s)是休息时间,为什么要把这段时间包括在内?比较可信的原因应该还是为了尽可能扩充训练样本的数量。

1.2 实验架构

GAN架构
对于GAN的架构没什么好说的,就是一个最经典的GAN训练过程而已。生成器接收噪声生成信号,然后判别器来将生成的样本和真实样本尽力分别开。这里就不过多赘述了。
关于GAN的相关细节可参考reference[3],这里不过多赘述。

二、GAN网络具体架构

论文中的GAN架构也并不复杂。主要用双向lstm提取了特征。作者给的github仓库看起来挺杂乱的,里面还塞了学习用的mnist相关代码,且神经网络用的是tensorflow实现的,看了半天完全没有和论文相关的东西,令人疑惑…
所以我们自己用pytorch复现一下:

2.1 生成器设置

生成器
这个网络架构很简单,我就直接贴上代码了:

class Generator(nn.Module):
    def __init__(self, noise_dim=50, seq_length=500):
        super(Generator, self).__init__()
        self.seq_length = seq_length
        self.noise_dim = noise_dim
        self.lstm1 = nn.LSTM(input_size=noise_dim, hidden_size=30, batch_first=True, bidirectional=True)
        self.lstm2 = nn.LSTM(input_size=60, hidden_size=30, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(0.2)
        self.dense = nn.Linear(60, 1)
        self.activation = nn.Tanh()

    def forward(self, x):
        x, _ = self.lstm1(x)
        x = self.activation(x)
        x, _ = self.lstm2(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.dense(x)
        x = self.activation(x)
        return x

2.2 判别器设置

判别器

class Discriminator(nn.Module):
    def __init__(self, seq_length=500):
        super(Discriminator, self).__init__()
        self.seq_length = seq_length
        self.lstm1 = nn.LSTM(input_size=1, hidden_size=30, batch_first=True, bidirectional=True)
        self.lstm2 = nn.LSTM(input_size=60, hidden_size=30, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(0.2)
        self.flatten = nn.Flatten()
        self.dense = nn.Linear(60 * seq_length, 1)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self, x):
        x, _ = self.lstm1(x)
        x = self.tanh(x)
        x, _ = self.lstm2(x)
        x = self.tanh(x)
        x = self.dropout(x)
        x = self.flatten(x)
        x = self.dense(x)
        x = self.sigmoid(x)
        return x

判别器也很简单,但是图里判别器的输入和生成器输出的shape不一致,一开始照这个写害我出了个bug,所以生成的fake_data输入判别器时还得unsqueeze一下。

2.3 模型运行

对于GAN的具体训练这里略去不谈,请看reference[2],我也是从那里学的。这里贴一下相关代码,GAN的训练主要是生成器和判别器拉扯的过程:

def model_train(dataloader, device, epochs=150, model_save=False,
                             model_save_path="./model_param/MIEEGModel"):
    # 创建模型和优化器实例
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)
    # 生成器优化器,目的是减小生成器生成的数据与真实数据的差距
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0001)
    # 判别器优化器,目的是增大判别器判别生成器生成的数据与真实数据的差距
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0001)
    # 二分类交叉熵损失函数
    criterion = nn.BCELoss()

    for epoch in range(epochs):
        for i, data in enumerate(dataloader):
            real_data = data[0].to(device)
            # print(real_data.shape)
            # 因为是单通道数据,所以需要去掉通道维度
            # real_data = real_data.squeeze(1)
            bs = real_data.size(0)
            # 生成噪声输入生成器
            noise = torch.randn(bs, 500, 50).to(device)

            # 训练生成器
            # [bs, 500, 50] -> [bs, 500, 1]
            fake_data = generator(noise)

            optimizer_G.zero_grad()
            # 计算生成器的损失
            g_loss = criterion(discriminator(fake_data), torch.ones(bs, 1).to(device))
            g_loss.backward()
            optimizer_G.step()

            # 计算判别器的损失
            optimizer_D.zero_grad()
            real_loss = criterion(discriminator(real_data), torch.ones(bs, 1).to(device))
            fake_loss = criterion(discriminator(fake_data.detach()), torch.zeros(bs, 1).to(device))
            d_loss = real_loss + fake_loss

            # 观察real_loss与fake_loss,同时下降同时达到最小值,并且差不多大,说明D已经稳定了
            d_loss.backward()
            optimizer_D.step()

            if i % 200 == 0:
                print(f"Epoch {epoch}, Iteration {i}, D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")

    if model_save:
        # 路径 + 时间戳
        torch.save(generator.state_dict(), model_save_path + "/generator_" + str(int(time.time())) + ".pt")

我运行了一下没报错,但是现在也仅能保证不报错。

写在最后

关于论文中的验证fake_date手法以及后续可能的一些改进手段我们之后再说,这个星期的任务就先算完成了。不得不说,之前的复现都是自己改下作者的代码改到能跑,这次真正复现的收获还是挺大的。

2024/12/8:
今天对之前训练好的模型生成了一下数据,结果不尽如人意,可能在模型的参数方面或者数据的预处理方面有所遗漏。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值