基于小波分析与深度学习的脑电信号分类(matlab)

原理

通过小波变换对运动想象信号进行特征提取,生成时频图像作为神经网络的输入。

实现

使用BCI竞赛2008–Graz dataset A中的A01受试者的数据作为数据集。

采样

采样频率250Hz,每个数据前三个为伪迹参考信号,后6个为EEG信号集,
每一个划成48次,四个任务,每个任务12次;
每次任务大概8s,每次大概从3-6s(750-1500点)为运动想像时间,
拟采集770-1462点,每个样本512个点,采样间隔20个点,采集25个样本;

%% 采样信号
clear;
fs = 250;     % 采样频率 250Hz
Na = 96735;   %
Nt = 48;      % 一个EEG信号集划成48次数据,即四个任务、每个任务12次
Ns = 25;      % 样本数 25个
Np = 20;      % 采样间隔20个点
N  = 256;
%%
x00 = load('A01T');
%%
for k=1:6  % 后6个为EEG信号集,即data {1,4} {1,5} {1,6} {1,7} {1,8} {1,9}
    x01 = x00.data{1, k+3}.X;    % EEG信号
    y01 = x00.data{1, k+3}.y;    % 类别
    t = x00.data{1, k+3}.trial;  % 试验(trials),包含伪迹
    t(Nt+1) = Na;
    %figure
    for i = 1:Nt
        x0 = x01(t(i):t(i+1), :);
        %subplot(6,8,i);
        %plot(x0(:,1));xlim([0 2100]);ylim([-100 100]);
        for j = 1:Ns
            x1 = x0(750+Np*(j-1):750+Np*(j-1)+N-1, 1:22);
            x2 = (x1-min(x1(:)))/(max(x1(:))-min(x1(:)));  % 最大最小归一化
            XTr(:, :, 1, 1200*(k-1)+25*(i-1)+j) = x2;
            YTr(1, 1200*(k-1)+25*(i-1)+j) = categorical(y01(i));
        end
        clear x0; % 每次迭代x0的长度会发生变化
    end
end

save SubA_Train XTr YTr;

请添加图片描述

小波变换

%% 小波变换
clear
load SubA_Train;
%%

id=[8 10 12];  % 选三个电极,
parfor i=1:length(XTr)
    for j=1:3 
        x = XTr(:,id(j),1,i);
        x1 = abs(cwt(x));  % 小波变换
        XTrft(:,:,j,i) = (x1-min(x1(:)))/(max(x1(:))-min(x1(:)));   % 归一化
    end     
end

save SubA_TF_Train XTrft YTr;

%% 可视化一个样本为彩色图片
size(XTrft(:,:,:,1))  % 51×256×3
categories(YTr)  % 查看类别数

figure;
imshow(XTrft(:,:,:,1))

时频图样张:
请添加图片描述

把时频图保存到本地文件夹

  • 图片的尺寸为51×256×3
  • 4个文件夹(0、1、2、3),每个文件夹的图片为同一类信号
%% 转成图片格式,先新建一个images文件夹,然后在images里面新建4个文件夹,分别为0、1、2、3.
load SubA_TF_Train
for i = 1:7200
    k = double(string(YTr(1,i)))-1;  % label
    imwrite(XTrft(:,:,:,i),['images\',num2str(k)','\',num2str(i),'.jpg'])  % 保存为图片
end

训练和评估

利用deepNetworkDesigner搭建网络,导出到工作区,训练。需要注意的是,网络的输出层为4类。可以采用典型的网络,例如Googlenet、resnet等。

clear;

%% 导入数据集
imdsTrain = imageDatastore("images","IncludeSubfolders",true,"LabelSource","foldernames");
[imdsTrain, imdsValidation] = splitEachLabel(imdsTrain,0.8,"randomized");

% 调整图像大小以匹配网络输入层
% inputsize = [256 256 3];
inputsize = [51 256 3];
augimdsTrain = augmentedImageDatastore(inputsize,imdsTrain);
augimdsValidation = augmentedImageDatastore(inputsize,imdsValidation);

%% 网络结构alexnet
% Net = alexnet;
% Net = googlenet;
% Net = inceptionresnetv2;
deepNetworkDesigner

%% 训练网络
miniBatchSize = 128;
learnRate = 0.0001;
valFrequency = floor(0.8*7200.0/miniBatchSize);
options = trainingOptions('adam', ...
    'InitialLearnRate',learnRate, ...
    'MaxEpochs',20, ...
    'MiniBatchSize',miniBatchSize, ...
    'Shuffle','every-epoch', ...
    'Plots','training-progress', ...
    'Verbose',false, ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',valFrequency, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropFactor',0.1, ...
    'LearnRateDropPeriod',5);
trainedNet = trainNetwork(augimdsTrain, lgraph_1, options);

%% 评估
% 准确率
% 训练集
[YPred,probs] = classify(trainedNet,augimdsTrain);
accuracy = mean(YPred == imdsTrain.Labels)
disp("training acc: " + accuracy*100 + "%")
% 验证集
[YPred,probs] = classify(trainedNet,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)
disp("val acc: " + accuracy*100 + "%")

% 混淆矩阵
figure('Units','normalized','Position',[0.2 0.2 0.4 0.4]);
cm = confusionchart(imdsValidation.Labels,YPred);
cm.Title = 'Confusion Matrix for Validation Data';
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

结果

  • alexnet
    • training acc: 99.3924%
    • val acc: 84.7917%
  • resnet
    • training acc: 100%
    • val acc: 85.4861%
  • simplenet(我自己搭建的网络)
    • training acc: 98.7674%
    • val acc: 90.0694%

请添加图片描述请添加图片描述

python版本的数据集

  • https://github.com/bregydoc/bcidatasetIV2a
  • 1
    点赞
  • 77
    收藏
    觉得还不错? 一键收藏
  • 7
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值