2024/11/29-2024/12/2:
简单阅读一下论文与代码。想尽量在这几天完成代码的复现工作,不然一直拖着太慢了。
这次文献的阅读主要解决我一直以来的一个疑问:如何判断生成的EEG信号是“好”的?
论文地址:MIEEG-GAN: Generating Artificial Motor Imagery Electroencephalography
Signals
论文作者github地址:Generative-networks-EEG
博主复现项目gitee地址:基于GAN的运动想象EEG信号生成
reference:
[1]BCI competition IV 2b数据集简单处理与批量读取
[2]GAN原理讲解与PyTorch手写逐行讲解
一、实验设置
1.1 数据集相关
该论文采用数据集为BCI competition IV 2b,相关信息可见于reference[1],该数据集是一个二分类的MI数据集,较为简单。
在该数据集中,运动想象信号的窗口是3-7s处。数据原本有3个EEG通道,而论文只用了C3这个通道,估计是为了方便处理以及控制训练GAN的时长,但是个人觉得一起生成3个通道相比于只做一个其实没差很多。
在样本划分方面将这共4s的数据再次划分窗口,窗口大小为2s,窗口移动步长为200ms,从每一段4s的数据中得到了11个可用于训练的样本,每个样本长度为2s(在250hz下是500个采样点)。
而这些样本随后又会被用于构建256ms子窗口(64个采样点),step为56ms(换言之相邻窗口重叠50个采样点)的STFT图像,快速傅里叶变换(FFT)点的数量为512。因此频谱图的大小为257×32,其中257是频率分量的数量,32是时间点的数量。从该频谱图中,选择 13-32 Hz 的 β 频谱图,其大小为 41×32。这些图像最后被用于验证生成EEG信号的可行性。
这里比较奇怪的是,在每个运动想象的4s时间中,其中最后1s(6-7s)是休息时间,为什么要把这段时间包括在内?比较可信的原因应该还是为了尽可能扩充训练样本的数量。
1.2 实验架构
对于GAN的架构没什么好说的,就是一个最经典的GAN训练过程而已。生成器接收噪声生成信号,然后判别器来将生成的样本和真实样本尽力分别开。这里就不过多赘述了。
关于GAN的相关细节可参考reference[3],这里不过多赘述。
二、GAN网络具体架构
论文中的GAN架构也并不复杂。主要用双向lstm提取了特征。作者给的github仓库看起来挺杂乱的,里面还塞了学习用的mnist相关代码,且神经网络用的是tensorflow实现的,看了半天完全没有和论文相关的东西,令人疑惑…
所以我们自己用pytorch复现一下:
2.1 生成器设置
这个网络架构很简单,我就直接贴上代码了:
class Generator(nn.Module):
def __init__(self, noise_dim=50, seq_length=500):
super(Generator, self).__init__()
self.seq_length = seq_length
self.noise_dim = noise_dim
self.lstm1 = nn.LSTM(input_size=noise_dim, hidden_size=30, batch_first=True, bidirectional=True)
self.lstm2 = nn.LSTM(input_size=60, hidden_size=30, batch_first=True, bidirectional=True)
self.dropout = nn.Dropout(0.2)
self.dense = nn.Linear(60, 1)
self.activation = nn.Tanh()
def forward(self, x):
x, _ = self.lstm1(x)
x = self.activation(x)
x, _ = self.lstm2(x)
x = self.activation(x)
x = self.dropout(x)
x = self.dense(x)
x = self.activation(x)
return x
2.2 判别器设置
class Discriminator(nn.Module):
def __init__(self, seq_length=500):
super(Discriminator, self).__init__()
self.seq_length = seq_length
self.lstm1 = nn.LSTM(input_size=1, hidden_size=30, batch_first=True, bidirectional=True)
self.lstm2 = nn.LSTM(input_size=60, hidden_size=30, batch_first=True, bidirectional=True)
self.dropout = nn.Dropout(0.2)
self.flatten = nn.Flatten()
self.dense = nn.Linear(60 * seq_length, 1)
self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()
def forward(self, x):
x, _ = self.lstm1(x)
x = self.tanh(x)
x, _ = self.lstm2(x)
x = self.tanh(x)
x = self.dropout(x)
x = self.flatten(x)
x = self.dense(x)
x = self.sigmoid(x)
return x
判别器也很简单,但是图里判别器的输入和生成器输出的shape不一致,一开始照这个写害我出了个bug,所以生成的fake_data输入判别器时还得unsqueeze一下。
2.3 模型运行
对于GAN的具体训练这里略去不谈,请看reference[2],我也是从那里学的。这里贴一下相关代码,GAN的训练主要是生成器和判别器拉扯的过程:
def model_train(dataloader, device, epochs=150, model_save=False,
model_save_path="./model_param/MIEEGModel"):
# 创建模型和优化器实例
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# 生成器优化器,目的是减小生成器生成的数据与真实数据的差距
optimizer_G = optim.Adam(generator.parameters(), lr=0.0001)
# 判别器优化器,目的是增大判别器判别生成器生成的数据与真实数据的差距
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0001)
# 二分类交叉熵损失函数
criterion = nn.BCELoss()
for epoch in range(epochs):
for i, data in enumerate(dataloader):
real_data = data[0].to(device)
# print(real_data.shape)
# 因为是单通道数据,所以需要去掉通道维度
# real_data = real_data.squeeze(1)
bs = real_data.size(0)
# 生成噪声输入生成器
noise = torch.randn(bs, 500, 50).to(device)
# 训练生成器
# [bs, 500, 50] -> [bs, 500, 1]
fake_data = generator(noise)
optimizer_G.zero_grad()
# 计算生成器的损失
g_loss = criterion(discriminator(fake_data), torch.ones(bs, 1).to(device))
g_loss.backward()
optimizer_G.step()
# 计算判别器的损失
optimizer_D.zero_grad()
real_loss = criterion(discriminator(real_data), torch.ones(bs, 1).to(device))
fake_loss = criterion(discriminator(fake_data.detach()), torch.zeros(bs, 1).to(device))
d_loss = real_loss + fake_loss
# 观察real_loss与fake_loss,同时下降同时达到最小值,并且差不多大,说明D已经稳定了
d_loss.backward()
optimizer_D.step()
if i % 200 == 0:
print(f"Epoch {epoch}, Iteration {i}, D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")
if model_save:
# 路径 + 时间戳
torch.save(generator.state_dict(), model_save_path + "/generator_" + str(int(time.time())) + ".pt")
我运行了一下没报错,但是现在也仅能保证不报错。
写在最后
关于论文中的验证fake_date手法以及后续可能的一些改进手段我们之后再说,这个星期的任务就先算完成了。不得不说,之前的复现都是自己改下作者的代码改到能跑,这次真正复现的收获还是挺大的。
2024/12/8:
今天对之前训练好的模型生成了一下数据,结果不尽如人意,可能在模型的参数方面或者数据的预处理方面有所遗漏。