信号处理--基于混合CNN和transfomer自注意力的多通道脑电信号的情绪分类的简单应用

本文探讨了一种结合卷积神经网络和Transformer的混合模型在基于EEG的情绪分类任务中的应用。实验比较了逻辑回归、支持向量机、随机森林和多层感知机等多种传统机器学习模型,以及深度学习的多层感知机和混合模型在情绪识别上的性能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

目录

关于

工具

数据集

数据集简述

方法实现

数据读取

​编辑数据预处理

传统机器学习模型(逻辑回归,支持向量机,随机森林)

多层感知机模型

CNN+transfomer模型

代码获取


关于

  • 本实验利用结合了卷积神经网络 (CNN) 和 Transformer 组件的混合架构,实现基于 EEG 的有效情绪分类。
  • 尝试各种机器学习模型,包括逻辑回归、支持向量机 (SVM)、随机森林分类器和多层感知器 (MLP) 神经网络,以比较不同模型的性能。 

 图片来自于: https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9991178

工具

数据集

数据集简述

脑电图数据是从两名受试者(1 名男性、1 名女性,年龄 20-22 岁)收集的,针对特定电影剪辑引发的六种情绪状态(积极、消极、中性)中的每一种状态。该数据集包括从脑电波中收集的 324,000 个数据点,这些数据点被重新采样到 150Hz。还收集了中性脑电波数据,作为代表受试者静息情绪状态的第三类数据。从四个电极(TP9、AF7、AF8、TP10)记录 EEG 数据,并进行处理以生成通过 1 秒滑动窗口提取的统计特征数据集。

图片来源:https://www.researchgate.net/figure/This-figure-shows-the-standard-locations-for-measuring-EEG-as-per-10-20-International_fig2_358644174 

方法实现

数据读取
raw_eeg_data = pd.read_csv('../data/features_raw.csv')
raw_eeg_data.head()

# plot the F8 column
plt.figure(figsize=(20, 5))
plt.plot(raw_eeg_data['F8'])
plt.title('F8 Electrode Data')
plt.ylabel('Voltage (uV)')
plt.xlabel('Time')
plt.show()


# plot the F7 column
plt.figure(figsize=(20, 5))
plt.plot(raw_eeg_data['F7'])
plt.title('F7 Electrode Data')
plt.ylabel('Voltage (uV)')
plt.xlabel('Time')
plt.show()
数据预处理
X = eeg_emotions_data.drop(['label'], axis=1)
y = eeg_emotions_data['label']

# Encoding categorical data
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

labelencoder_emotions = LabelEncoder()
y = labelencoder_emotions.fit_transform(y)


# Standardizing the features in the dataset
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()

X = scaler.fit_transform(X)

传统机器学习模型(逻辑回归,支持向量机,随机森林)
from sklearn.linear_model import LogisticRegression
import pickle

# Create a logistic regression classifier
model = LogisticRegression(random_state=2003, multi_class='multinomial', max_iter=1000)

# Train the model
model.fit(X_train, y_train)

# Evaluate the model
evaluate_model(y_test, model.predict(X_test))

from sklearn.svm import SVC

# Create a model: a support vector classifier
model = SVC(kernel='rbf', gamma='auto', C=1.0, random_state=2003)

# Train the model
model.fit(X_train, y_train)

# Evaluate the model
evaluate_model(y_test, model.predict(X_test))

from sklearn.ensemble import RandomForestClassifier

# Create a random forest Classifier.
model = RandomForestClassifier(n_estimators=100, random_state=2003)

# Train the model
model.fit(X_train, y_train)

# Evaluate the model
evaluate_model(y_test, model.predict(X_test))

在传统机器模型,我们可以发现随机森林的性能表现最好; 

多层感知机模型
class EEGClassifier(nn.Module):
    def __init__(self, input_dim, num_classes, hidden_dim=256):
        super(EEGClassifier, self).__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


input_dim = 2548  # Number of features in EEG signal
num_classes = 3   # Number of classes for classification
model = EEGClassifier(input_dim, num_classes)
loss = nn.CrossEntropyLoss()

CNN+transfomer模型
class EEGConformer(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(EEGConformer, self).__init__()

        # CNN
        self.conv1 = nn.Conv2d(1, 40, kernel_size=(1, 25), stride=(1, 1))
        self.conv2 = nn.Conv2d(40, 40, kernel_size=(1, input_dim), stride=(1, 1))
        self.batchnorm = nn.BatchNorm2d(40)

        # Transformer
        self.layernorm1 = nn.LayerNorm(40)
        self.multiheadattention = nn.MultiheadAttention(40, 1)
        self.layernorm2 = nn.LayerNorm(40)

        self.feedworward_block = nn.Sequential(
            nn.Linear(40, 32),
            nn.GELU(),
            nn.Dropout(p=0.1),
            nn.Linear(32, 40)
        )

        # MLP
        self.fc1 = nn.Linear(40, 32)
        self.fc2 = nn.Linear(32, 32)
        self.fc3 = nn.Linear(32, num_classes)

    def forward(self, x):
        # CNN
        x = x.unsqueeze(1).unsqueeze(1)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.batchnorm(x)

        # Transformer
        x = x.squeeze()
        x = self.layernorm1(x)
        attn_out = self.multiheadattention(x, x, x)
        x = x + nn.Dropout(0.1)(attn_out[0])
        x = self.layernorm2(x)
        x = self.feedworward_block(x)
        x = nn.Dropout(p=0.1)(x)

        # MLP
        x = self.fc1(x)
        x = F.elu(x)
        x = nn.Dropout(p=0.5)(x)
        x = self.fc2(x)
        x = F.elu(x)
        x = nn.Dropout(p=0.3)(x)
        x = self.fc3(x)
        
        return x

input_dim = 2524  # Number of features in EEG signal
num_classes = 3   # Number of classes for classification
model = EEGConformer(input_dim, num_classes)
loss = nn.CrossEntropyLoss()

代码获取

后台私信,注明来意和文章名称;

其他问题,欢迎沟通交流。

### 解决PyCharm无法加载Conda虚拟环境的方法 #### 配置设置 为了使 PyCharm 能够成功识别并使用 Conda 创建的虚拟环境,需确保 Anaconda 的路径已正确添加至系统的环境变量中[^1]。这一步骤至关重要,因为只有当 Python 解释器及其关联工具被加入 PATH 后,IDE 才能顺利找到它们。 对于 Windows 用户而言,在安装 Anaconda 时,默认情况下会询问是否将它添加到系统路径里;如果当时选择了否,则现在应该手动完成此操作。具体做法是在“高级系统设置”的“环境变量”选项内编辑 `Path` 变量,追加 Anaconda 安装目录下的 Scripts 文件夹位置。 另外,建议每次新建项目前都通过命令行先激活目标 conda env: ```bash conda activate myenvname ``` 接着再启动 IDE 进入工作区,这样有助于减少兼容性方面的问题发生概率。 #### 常见错误及修复方法 ##### 错误一:未发现任何解释器 症状表现为打开 PyCharm 新建工程向导页面找不到由 Conda 构建出来的 interpreter 列表项。此时应前往 Preferences/Settings -> Project:...->Python Interpreter 下方点击齿轮图标选择 Add...按钮来指定自定义的位置。按照提示浏览定位到对应版本 python.exe 的绝对地址即可解决问题。 ##### 错误二:权限不足导致 DLL 加载失败 有时即使指定了正确的解释器路径,仍可能遇到由于缺乏适当的操作系统级许可而引发的功能缺失现象。特别是涉及到调用某些特定类型的动态链接库 (Dynamic Link Library, .dll) 时尤为明显。因此拥有管理员身份执行相关动作显得尤为重要——无论是从终端还是图形界面触发创建新 venv 流程均如此处理能够有效规避此类隐患。 ##### 错误三:网络连接异常引起依赖下载超时 部分开发者反馈过因网速慢或者其他因素造成 pip install 操作中途断开进而影响整个项目的初始化进度条卡住的情况。对此可尝试调整镜像源加速获取速度或是离线模式预先准备好所需资源包后再继续后续步骤。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

医学信号图像玩家

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值