利用matlab deep learning toolbox 实现DNN网络训练

最近工作发现python的scipy.singal.stft函数的返回值第一列有问题,并且解决不了,因此换用matlab spectrogram()进行绘图,因此利用matlab搭建神经网络。研究了一天发现matlab 相比于 python 其网络的可视化效果更好,图表输出更美观。总结一下matlab deep learning toolbox的使用方式。

deep learning toolbox的使用方法和keras比较类似,训练模型可以归结为4步:定义训练数据→定义神经网络模型→配置学习过程→训练模型。(写代码和写文章的电脑不是同一台,下列代码可能存在拼写问题)

1、训练数据

我采用的训练数据是仿真信号y=cos(2piwt+φ),然后利用awgn()加了噪声。输入值是y的延迟信号中的一段,共2000个点,记为x;输出值是y自身的某一点信号;二者存在对应关系。假设y为[12000,1],则x的对应长度为[14000,1]。

在划分训练集这块,matlab的支持文档都是针对图像或者文件,利用函数建立一个生成器,把数据读进去,在toolbox目录下有helperModClassFrameGenerator、helperModClassFrameStore等作为参考,利用这些函数可以直接打乱数据集,划分训练、验证、测试集。由于我的数据比较简单,就没有用到这些。

% main
xTrain = cell(12000,1);
yTrain = cell(12000,1);
for ii = 1:12000
    xTrain{ii,1} = x(ii:ii+2000-1)
    yTrain{ii,1} = y(ii)
end

xTrain和yTrain是两个元胞数组,存储了训练数据,xTrain是网络输入,yTrain是输出,如果y是标签的话可以把它改成逻辑型。

2、定义神经网络模型

matlab的神经网络模型层搭建方式和keras有些类似,但比keras少了很多内容,新建层上面也会比keras难一些,不过查阅资料发现matlab可以直接读取python建立的模型,具体操作还未验证。

实验的网络很简单,是一个3层的全连接网络,可以表示为:输入→全连接→tanh→全连接→输出。代码实现方式如下:

function net=k_models(inputsize)

    net=[
        sequenceInputLayer(inputsize,'Normalization','none','Name','Input Layer')

        fullyConnectedLayer(50,'Name','fc1')
        tanhLayer('Name','tanh1')
        fullyConnectedLayer(1,'Name','fc2')

        regressionLayer('Name','Output Layer')
    ];

end

其中最后的输出层定义了loss function,整体网络也可以采用matlab deep network designer工具箱进行生成。

matlab 提供了一个函数可以可视化网络

% main
modnet = k_models(2000)
analyzeNetwork(modnet)

3、配置学习过程

% main
options = trainingOptions('sgdm','InitialLearnRate',Ir,'MaxEpochs',50,'MiniBatchSize',32, 'Shuffle','every-epoch','Plots','training-progress','Verbose',0,'LearnRateSchedule', 'piecewise','LearnRateDropFactor',0.1,'LearnRateDropPeriod',0.1,'ExecutionEnvironment','gpu')

trainingOptions中包含了所有训练信息,比如梯度下降方式,学习率,Epoch,BatchSize等等,好像是默认每Epoch是全体数据。

Verbose是是否将训练过程打印在命令行中,关掉可以提高程序运行的时间;Plots是是否可视化训练效果,可以打开看看,比tensorboard视图好看;ExecutionEnvironment是运行环境cpu或者gpu。其他参数可以在trainingOptions函数中查阅。

4、训练模型

训练模型非常简单,给输入,输出,模型,训练讯息即可。

% main
model = trainNetwork(xTrain,yTrain,modnet,options)

5、测试

测试也非常简单

% main
z = predict(model,xTest)

其中xTest是测试集,也是一个元胞数组;输出z和yTrain的格式相同,同为元胞数组。

 

  • 7
    点赞
  • 68
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值