mxnet使用模型预测的初步心得

主要步骤:

  1. 加载模型结构json文件
  2. 加载模型参数params文件
  3. 加载标签synset.txt文件
  4. 创建predictor或者predictor handler
  5. 加载目标文件
  6. 预测(predict)
  7. 获取预测结果
  8. 如果是分类预测模型,则需要根据输出的向量取出最大可能的位置,根据synset文件确定分析结果
  9. 输出预测结果

上述步骤mxnetAPI:

mxnet的API对于c++开发者来说分为c分格的API和C++分格API,下面分开介绍

C分格API
  • 创建预测句柄,通过out获取预测器句柄,使用预测期句柄进行预测

     int MXPredCreate   (   const char *    symbol_json_str,    const void *    param_bytes,    int     param_size,    int     dev_type,    int     dev_id,    uint32_t    num_input_nodes,    const char **   input_keys,    const uint32_t *    input_shape_indptr,    const uint32_t *    input_shape_data,    PredictorHandle *   out )          /*   参量    symbol_json_str  模型结构json文件。    param_bytes      模型参数params文件。    param_size       参数文件的大小。    dev_type         分析设备类型,1:CPU,2:GPU    dev_id           预测器的设备ID。    num_input_nodes  网络的输入节点数,对于正向传播为1。    input_keys       输入参数的名称。对于正向传播,为{“ data”}    input_shape_indptr  每个输入节点的形状的索引指针。该数组的长度= num_input_nodes +1。对于需要4维输入的正向传播,其值为{0,4}。    input_shape_data    每个输入节点的形状的扁平数据。对于需要4维输入的前馈网,这是形状数据。    out              创建的预测器handler。    返回值:    成功时为0,失败时为-1。   */ 
    
  • 设置要预测的图片

    int MXPredSetInput(PredictorHandle handle,                             const char* key,                             const mx_float* data,                             mx_uint size);/*!    handle 预测期句柄    key    正向传播默认为 data    data   指向要设置的数据的指针    size   数据数组的大小,用于安全检查。        返回值         成功 0 ,失败 -1 */
    
  • 获取输出结果

     int MXPredGetOutputShape   (   PredictorHandle     handle,    uint32_t    index,    uint32_t **     shape_data,    uint32_t *  shape_ndim )    /* 获取预测的输出值。    handle      预测器的句柄。    index       输出节点的索引,如果只有一个输出,则设置为0。    shape_data  用户分配的数据来保存输出。    shape_ndim  数据数组的大小,用于安全检查。        返回值        成功时为0,失败时为-1。 */
    
C++API
  • symble对象用来管理j模型的json文件内容

    static Symbol mxnet::cpp::Symbol::Load(const std::string & file_name) //使用Load方法来加载模型文件
    
  • NDArray对象来管理模型params文件内容

    static void mxnet::cpp::NDArray::Load   (   const std::string &     file_name,std::vector< NDArray > *    array_list = nullptr,std::map< std::string, NDArray > *  array_map = nullptr )   /*file_name       目标文件.array_list      默认填空array_map       出参,从params文件返的一个map*/
    
  • MXDataIter对象用来管理目标分析图片

    //设置对应参数的值Operator& mxnet::cpp::Operator::SetParam    (   const std::string &     name,const T &   value )           //创建目标图片的迭代器MXDataIter mxnet::cpp::MXDataIter::CreateDataIter   (       )   //SetParam 必须在CreateDataIter前调用
    
  • Executor 预测期对象

    std::map<std::string, NDArray> mxnet::cpp::Executor::aux_dict   (       )   std::map<std::string, NDArray> mxnet::cpp::Executor::arg_dict   (       )   /*从连续的CPU内存区域执行同步复制。该函数将在执行复制之前调用WaitToWrite。这对于从没有被NDArray封装的现有内存区域复制数据非常有用(因此没有跟踪依赖项)。*/SyncCopyFromCPU()//进行正向传播,对目标图片根据模型进行预测。void mxnet::Executor::Forward   (   bool    is_train    )   
    
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值