深入 CoreML 模型定义

Core ML

Core ML是apple在iOS和MAC上的机器学习框架, 开发者可以使用Core ML将机器学习模型集成到应用中. Core ML架构于Accelerate, BNNS, Metal之上, 是apple针对其硬件深度优化后的框架, 可以大大加速开发者的工作, 让开发者集中精力于模型的训练和优化上.
Apple CoreML Architecture

Core ML所支持的模型文件是后缀为.mlmodel的文件, 使用非常简单, 将模型文件拖入工程即可. 不过, 这个文件里头定义的到底是啥呢?
要了解Core ML是如何工作的, Core ML能支持什么样的学习模型, 在使用Core ML的时候怎么更快的trouble-shooting, 如何在Core ML和其他平台的机器学习框架之间切换,则需要了解Core ML的模型文件的规范.

Core ML Format

Apple官方的Core ML模型的Spec在这里: Core ML Model Format Specification
这种文档看起来比较枯燥, 让我们从一个例子入手来学习.mlmodel文件里的奥秘.
AgeNet 是一个开放的识别人的年龄和性别的模型, 可以在下面的链接下载:
https://github.com/volvet/InsideCoreMLModel/blob/master/python/AgeNet.mlmodel

.mlmode文件实质上是protobuf文件, proto文件定义于
https://github.com/apple/coremltools/tree/master/mlmodel/format
使用下面的命令将Core ML的proto文件转为 python代码

protoc --python_out=. *.proto

然后就可以开始解析mlmodel文件了.

最上层的Message是Model, 定义如下:

message Model {
    int32 specificationVersion = 1;
    ModelDescription description = 2;
    // start at 200 here
    // model specific parameters:
    oneof Type {
        // pipeline starts at 200
        PipelineClassifier pipelineClassifier = 200;
        PipelineRegressor pipelineRegressor = 201;
        Pipeline pipeline = 202;

        // regressors start at 300
        GLMRegressor glmRegressor = 300;
        SupportVectorRegressor supportVectorRegressor = 301;
        TreeEnsembleRegressor treeEnsembleRegressor = 302;
        NeuralNetworkRegressor neuralNetworkRegressor = 303;
        BayesianProbitRegressor bayesianProbitRegressor = 304;

        // classifiers start at 400
        GLMClassifier glmClassifier = 400;
        SupportVectorClassifier supportVectorClassifier = 401;
        TreeEnsembleClassifier treeEnsembleClassifier = 402;
        NeuralNetworkClassifier neuralNetworkClassifier = 403;

        // generic models start at 500
        NeuralNetwork neuralNetwork = 500;

        // Custom model
        CustomModel customModel = 555;
        ...
      }
    }

所以解析就从Model开始

if __name__ == '__main__':
    model = Model_pb2.Model()
    with open('AgeNet.mlmodel', 'rb') as f:
        model.ParseFromString(f.read())
    
    print(model.description)
    print(model.specificationVersion)
    if model.HasField('neuralNetworkClassifier'):
        parseNeuralNetworkC
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值