使用matlab构建深度神经网络:深度网络设计器、analyzeNetwork

深度网络设计器 (Deep Network Designer) 概述

深度网络设计器是 MATLAB 中的一个应用程序,允许您创建、编辑、可视化、分析和导出深度学习网络。通过这个工具,您可以:

  • 构建和编辑网络架构。
  • 加载和修改预训练网络,进行迁移学习。
  • 导入来自 PyTorch® 和 TensorFlow™ 的模型。
  • 分析网络的结构,确保其定义正确。
  • 使用压缩技术减少网络的内存占用。
  • 将网络导出到 Simulink® 进行进一步的模拟或嵌入。
  • 生成 MATLAB® 代码以便重新创建网络架构。

主要功能

  1. 构建、编辑和组合网络

    • 通过拖放方式创建和编辑神经网络,可以手动调整网络层或组合已有的预训练网络进行迁移学习。
  2. 导入预训练网络

    • 支持导入 PyTorch® 和 TensorFlow™ 模型,并允许对其进行调整和修改以适应新的任务。
  3. 迁移学习

    • 可以编辑最后的可学习层(如分类层)来适应新的数据集。例如,调整类别数量或修改学习率,使网络能够更好地适应新任务。
  4. 从外部平台导入网络

    • 深度网络设计器支持从外部平台(如 PyTorch 和 TensorFlow)导入网络,进行转换和修改后,可以直接用于 MATLAB 环境中的训练和推断。
  5. 压缩分析

    • 该工具能够分析网络并确定如何使用剪枝等技术来减少内存使用,并生成适当的压缩分析报告。
  6. Simulink 集成

    • 训练完成的网络可以直接导出到 Simulink® 进行模拟和部署。
  7. 生成 MATLAB 代码

    • 您可以生成 MATLAB 代码来重新创建网络架构,便于复现或进行迁移学习。

示例

1. 构建深度神经网络

通过深度网络设计器创建一个简单的分类网络:

deepNetworkDesigner

在 App 中,您可以创建一个空白网络,拖放层并配置其属性,例如输入层、全连接层、ReLU 层等,最后连接这些层以完成网络结构。

2. 准备预训练网络进行迁移学习

加载一个预训练的网络(如 SqueezeNet),然后调整网络的最后一层,以适应新的分类任务:

deepNetworkDesigner

选择 SqueezeNet,在最后的卷积层上进行修改,设置新的类数,并调整学习率。

3. 导入外部平台的网络

通过深度网络设计器,您可以导入 PyTorch 或 TensorFlow 的模型,例如从 PyTorch 导入 MNASNet:

modelfile = matlab.internal.examples.downloadSupportFile("nnet", "data/PyTorchModels/mnasnet1_0.pt");
deepNetworkDesigner

导入后,您可以进行修改并准备网络进行训练。

4. 将网络导出到 Simulink

在深度网络设计器中编辑并训练完网络后,您可以将其导出到 Simulink 进行进一步的模拟或集成。

deepNetworkDesigner

如何访问

  • 在 MATLAB 工具栏的 App 选项卡下,点击 深度学习与机器学习 下的 深度网络设计器 图标。
  • 在 MATLAB 命令窗口,输入 deepNetworkDesigner 启动应用。

版本历史

  • R2018b:深度网络设计器首次推出。
  • R2024b:增加了压缩分析功能,支持使用剪枝等技术减少内存使用量。
  • R2024a:默认使用 dlnetwork 对象,推荐替代旧版的 LayerGraph 对象。

总结

深度网络设计器是一个强大的工具,能够帮助用户快速构建、编辑、训练和分析深度学习模型。无论是从零开始设计网络,还是修改预训练模型进行迁移学习,它都提供了丰富的功能和直观的界面。此外,还支持从外部平台导入模型、导出到 Simulink 以及生成用于重建网络的 MATLAB 代码。

analyzeNetwork 函数概述

analyzeNetwork 是一个 MATLAB 函数,用于分析深度学习网络的架构。它可以帮助您可视化网络的结构、检查网络架构的正确性,并在训练之前发现潜在的问题。该函数可以检测网络中的问题,如不正确的层输入大小、输入层连接问题和不合法的网络结构等。

主要功能

  1. 网络架构分析

    • 通过交互式可视化展示网络架构。
    • 提供有关层类型、激活尺寸、可学习参数、状态和总的可学习参数的详细信息。
    • 检测网络架构中的错误和警告,例如不正确的输入输出连接、层尺寸不匹配等问题。
  2. 多种输入形式支持

    • 可以分析 dlnetwork 对象、层数组或 taylorPrunableNetwork 对象。
    • 如果网络有多个输入,您可以为每个输入提供示例数据,帮助分析器理解每层的激活尺寸和参数。
  3. 返回分析信息

    • 可以返回一个 NetworkAnalysis 对象,包含总的可学习参数数目、每层的详细信息以及潜在的问题。
  4. 定制化的分析输出

    • 您可以控制显示的图形类型,例如显示 “analyzer” 图表或不显示任何图表,直接访问结果。

语法

  • analyzeNetwork(net):分析给定的网络并显示交互式图表和网络层的详细信息。
  • analyzeNetwork(net, X1, ..., Xn):使用指定的示例输入来分析网络,适用于输入未直接连接到输入层的网络。
  • info = analyzeNetwork(___):返回 NetworkAnalysis 对象,您可以编程访问分析结果。
  • ___ = analyzeNetwork(___, Plots=plotName):控制是否显示分析图,plotName 可以是 "analyzer""none"

示例

示例 1:分析训练好的网络
net = imagePretrainedNetwork("googlenet");
info = analyzeNetwork(net);

这将加载一个预训练的 GoogLeNet 网络,并通过交互式图表和表格显示网络架构的详细信息,包括层类型、激活尺寸和可学习参数等。

示例 2:修复网络架构中的错误
net = dlnetwork;
layers = [
    imageInputLayer([32 32 3])
    convolution2dLayer(5,16,Padding="same")
    reluLayer(Name="relu_1")
    convolution2dLayer(3,16,Padding="same",Stride=2)
    reluLayer
    additionLayer(2,Name="add_1")
    convolution2dLayer(3,16,Padding="same",Stride=2)
    reluLayer
    additionLayer(3,Name="add_2")
    fullyConnectedLayer(10)
    softmaxLayer];

net = addLayers(net, layers);
analyzeNetwork(net);

此代码会显示关于网络架构的问题和警告,并指导您如何修复错误,如层输入尺寸不匹配、未连接的输入等。

示例 3:使用示例输入分析网络
X1 = dlarray(rand([64 64 3 32]), "SSCB");
X2 = dlarray(rand([32 32 18 32]), "SSCB");
analyzeNetwork(net, X1, X2);

在此示例中,您提供了示例输入数据 X1X2 来分析网络结构,适用于多输入的网络架构。

输出结果

  • TotalLearnables:网络的总可学习参数数目。
  • LayerInfo:包含每一层的详细信息,如层的类型、激活尺寸、学习参数等。
  • Issues:包含网络中发现的问题。
  • NumErrors:网络中的错误数量。
  • NumWarnings:网络中的警告数量。
  • AnalysisDate:分析的日期和时间。

版本历史

  • R2018a:首次推出 analyzeNetwork 函数。
  • R2024b:开始支持分析网络中的压缩相关信息。
  • R2024a:开始支持分析 taylorPrunableNetwork 和包含 ProjectedLayer 的网络。

相关函数

  • trainnet:用于训练网络。
  • dlnetwork:定义深度学习网络。
  • plot:可视化网络架构。
  • summary:网络总结。
  • NetworkAnalysis:网络分析结果对象。

小贴士

  • 若要交互式地构建和可视化深度学习网络,可以使用 Deep Network Designer 应用。
  • 如果在分析过程中遇到错误,确保网络结构正确连接并适配输入数据。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值