基于机器学习的心律失常分类(六)——BP神经网络分类[MATLAB]
https://zhuanlan.zhihu.com/p/162804227
说明:
1、DATA矩阵为arma系数矩阵,1-6列为ARMA模型系数,第七列为所数类别(1-正常心电,2-左束支阻滞,3-右束支阻滞,4-室性早搏)
2、四个类别分别选取各1700个数据,共计6800个样本,取70%作为训练集。
%% I. 导入数据
matrix=DATA(:,1:6);
label=DATA(:,7);
%%
% 1. 随机产生训练集和测试集
n1 = randperm(1700);
n2 = randperm(1700);
n3 = randperm(1700);
n4 = randperm(1700);
%%
% 2. 训练集——4760(6800)个样本
P_train1 = matrix(n1(1:1190)😅;
P_train2 = matrix(1700+n2(1:1190)😅;
P_train3 = matrix(17002+n3(1:1190)😅;
P_train4 = matrix(17003+n4(1:1190)😅;
X_train = [P_train1;P_train2;P_train3;P_train4];
T_train1 = label(n1(1:1190)😅;
T_train2 = label(1700+n2(1:1190)😅;
T_train3 = label(17002+n3(1:1190)😅;
T_train4 = label(17003+n4(1:1190)😅;
y_train = [T_train1;T_train2;T_train3;T_train4];
%%
% 3. 测试集——20个样本
P_test1 = matrix(n1(1191:end)😅;
P_test2 = matrix(1700+n2(1191:end)😅;
P_test3 = matrix(17002+n3(1191:end)😅;
P_test4 = matrix(17003+n4(1191:end)😅;
X_test = [P_test1;P_test2;P_test3;P_test4];
T_test1 = label(n1(1191:end)😅;
T_test2 = label(1700+n2(1191:end)😅;
T_test3 = label(17002+n3(1191:end)😅;
T_test4 = label(17003+n4(1191:end)😅;
y_test = [T_test1;T_test2;T_test3;T_test4];
%% II. 数据归一化
[X_train,minI,maxI] = premnmx( X_train’) ;
X_train=X_train’;
%% III. 神经网络
% 1. 构造输出矩阵
s=length(y_train);
output=zeros(s,4); %分4类 [0 0 0 0]
for i=1:s
output(i,y_train(i))=1;
end
%%
% 2. 创建神经网络
net=newff(minmax(X_train’),[11 4],{‘tansig’ ‘tansig’},‘traingdx’);
%%
% 3. 设置训练参数
net.trainparam.show = 50 ; % 每间隔500步显示一次训练结果
net.trainparam.epochs = 100000 ; %允许最大训练步数
net.trainparam.goal = 0.001 ; %训练目标最小误差0.01
net.trainParam.lr = 0.05 ; %学习速率0.05
%%
% 4. 开始训练
net=train(net,X_train’,output’);
%% IV. 仿真
X_test = tramnmx ( X_test’ , minI, maxI ) ;
X_test=X_test’;
Y=sim(net,X_test’);
%% V. 统计识别正确率
[s1,s2]=size(Y);
T_sim=[];
hitNum=0;
for i=1:s2
[m,Index]=max(Y(:,i));
T_sim(i)=Index;
if(Index==y_test(i))
hitNum=hitNum+1;
end
end
sprintf(‘识别率是 %3.3f%%’,100 * hitNum / s2 )
T_sim=T_sim’;
T_test=y_test;
%% 混淆矩阵
H=[];
number_1_1= length(find(T_sim == 1 & T_test == 1));H(1,1)=number_1_1;
number_1_2= length(find(T_sim == 1 & T_test == 2));H(1,2)=number_1_2;
number_1_3= length(find(T_sim == 1 & T_test == 3));H(1,3)=number_1_3;
number_1_4= length(find(T_sim == 1 & T_test == 4));H(1,4)=number_1_4;
number_2_1= length(find(T_sim == 2 & T_test == 1));H(2,1)=number_2_1;
number_2_2= length(find(T_sim == 2 & T_test == 2));H(2,2)=number_2_2;
number_2_3= length(find(T_sim == 2 & T_test == 3));H(2,3)=number_2_3;
number_2_4= length(find(T_sim == 2 & T_test == 4));H(2,4)=number_2_4;
number_3_1= length(find(T_sim == 3 & T_test == 1));H(3,1)=number_3_1;
number_3_2= length(find(T_sim == 3 & T_test == 2));H(3,2)=number_3_2;
number_3_3= length(find(T_sim == 3 & T_test == 3));H(3,3)=number_3_3;
number_3_4= length(find(T_sim == 3 & T_test == 4));H(3,4)=number_3_4;
number_4_1= length(find(T_sim == 4 & T_test == 1));H(4,1)=number_4_1;
number_4_2= length(find(T_sim == 4 & T_test == 2));H(4,2)=number_4_2;
number_4_3= length(find(T_sim == 4 & T_test == 3));H(4,3)=number_4_3;
number_4_4= length(find(T_sim == 4 & T_test == 4));H(4,4)=number_4_4;
H