DNN Module

1. Introduction to the OpenCV 3.3 DNN module

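The OpenCV 3.3 release moved the DNN module from the extension (contrib) modules into the main OpenCV distribution. The module is said to have originated from Tiny-dnn and could load pre-trained Caffe models. OpenCV has since extended it to load models trained and exported by the mainstream deep learning frameworks. The common ones are: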

Caffe
TensorFlow
Torch/PyTorch 

The OpenCV DNN module already supports, and has been tested with, these common network architectures:

AlexNet
GoogLeNet v1 (also referred to as Inception-5h)
ResNet-34/50/...
SqueezeNet v1.1
VGG-based FCN (semantic segmentation network)
ENet (lightweight semantic segmentation network)
VGG-based SSD (object detection network)
MobileNet-based SSD (lightweight object detection network)

 

Location of the DNN samples in the OpenCV source tree:

opencv3.3\sources\samples\dnn

 

Functions and frameworks
Below are some of the functions we will use.


Converting images into input blobs for the network in dnn:
cv2.dnn.blobFromImage
cv2.dnn.blobFromImages
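
To make the idea of a "blob" more concrete, here is a minimal sketch of the core transformation, written in plain C++ with no OpenCV dependency. The helper name toPlanarBlob and its parameters are hypothetical and are not part of OpenCV. The real blobFromImage also resizes or crops the image and can swap the R and B channels, all of which is left out here. The sketch only shows the per-channel mean subtraction, the scale factor, and the reordering from interleaved (height, width, channel) pixels to planar (channel, height, width) floats.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not an OpenCV API): converts one interleaved 8-bit
// image (H x W x C) into a planar float buffer (C x H x W), subtracting a
// per-channel mean and then multiplying by a scale factor.
std::vector<float> toPlanarBlob(const std::vector<unsigned char>& pixels,
                                std::size_t height, std::size_t width,
                                std::size_t channels,
                                const std::vector<float>& mean,
                                float scale)
{
    assert(pixels.size() == height * width * channels);
    assert(mean.size() == channels);

    std::vector<float> blob(pixels.size());
    for (std::size_t c = 0; c < channels; ++c)
        for (std::size_t y = 0; y < height; ++y)
            for (std::size_t x = 0; x < width; ++x)
            {
                const std::size_t src = (y * width + x) * channels + c; // interleaved index
                const std::size_t dst = (c * height + y) * width + x;   // planar index
                blob[dst] = (static_cast<float>(pixels[src]) - mean[c]) * scale;
            }
    return blob;
}
```

For a single image the planar buffer corresponds to one entry of a batch, so a batch of N images would simply be N such buffers laid out one after another.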


Importing models from the various frameworks with the "create" methods:
cv2.dnn.createCaffeImporter
cv2.dnn.createTensorFlowImporter
cv2.dnn.createTorchImporter


Loading serialized models directly from disk with the "read" methods:
cv2.dnn.readNetFromCaffe
cv2.dnn.readNetFromTensorflow
cv2.dnn.readNetFromTorch
cv2.dnn.readTorchBlob


Once the model has been loaded from disk, the .forward method runs a forward pass over the image and returns the classification result.

 

2. Calling a Caffe model from the OpenCV 3.3 dnn module

Take D:\opencv3.3\sources\samples\dnn\caffe_googlenet.cpp as an example:

 

 
    /**M///
    //
    // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
    //
    // By downloading, copying, installing or using the software you agree to this license.
    // If you do not agree to this license, do not download, install,
    // copy or use the software.
    //
    //
    // License Agreement
    // For Open Source Computer Vision Library
    //
    // Copyright (C) 2013, OpenCV Foundation, all rights reserved.
    // Third party copyrights are property of their respective owners.
    //
    // Redistribution and use in source and binary forms, with or without modification,
    // are permitted provided that the following conditions are met:
    //
    // * Redistribution's of source code must retain the above copyright notice,
    // this list of conditions and the following disclaimer.
    //
    // * Redistribution's in binary form must reproduce the above copyright notice,
    // this list of conditions and the following disclaimer in the documentation
    // and/or other materials provided with the distribution.
    //
    // * The name of the copyright holders may not be used to endorse or promote products
    // derived from this software without specific prior written permission.
    //
    // This software is provided by the copyright holders and contributors "as is" and
    // any express or implied warranties, including, but not limited to, the implied
    // warranties of merchantability and fitness for a particular purpose are disclaimed.
    // In no event shall the Intel Corporation or contributors be liable for any direct,
    // indirect, incidental, special, exemplary, or consequential damages
    // (including, but not limited to, procurement of substitute goods or services;
    // loss of use, data, or profits; or business interruption) however caused
    // and on any theory of liability, whether in contract, strict liability,
    // or tort (including negligence or otherwise) arising in any way out of
    // the use of this software, even if advised of the possibility of such damage.
    //
    //M*/

    #include <opencv2/dnn.hpp>
    #include <opencv2/imgproc.hpp>
    #include <opencv2/highgui.hpp>
    #include <opencv2/core/utils/trace.hpp>
    using namespace cv;
    using namespace cv::dnn;

    #include <fstream>
    #include <iostream>
    #include <cstdlib>
    using namespace std;

    /* Find best class for the blob (i. e. class with maximal probability) */
    static void getMaxClass(const Mat &probBlob, int *classId, double *classProb)
    {
        Mat probMat = probBlob.reshape(1, 1); //reshape the blob to 1x1000 matrix
        Point classNumber;

        minMaxLoc(probMat, NULL, classProb, NULL, &classNumber);
        *classId = classNumber.x;
    }

    static std::vector<String> readClassNames(const char *filename = "synset_words.txt")
    {
        std::vector<String> classNames;

        std::ifstream fp(filename);
        if (!fp.is_open())
        {
            std::cerr << "File with classes labels not found: " << filename << std::endl;
            exit(-1);
        }

        std::string name;
        while (!fp.eof())
        {
            std::getline(fp, name);
            if (name.length())
                classNames.push_back( name.substr(name.find(' ')+1) );
        }

        fp.close();
        return classNames;
    }

    int main(int argc, char **argv)
    {
        CV_TRACE_FUNCTION();

        String modelTxt = "bvlc_googlenet.prototxt";
        String modelBin = "bvlc_googlenet.caffemodel";
        String imageFile = (argc > 1) ? argv[1] : "space_shuttle.jpg";

        Net net;
        try {
            //! [Read and initialize network]
            net = dnn::readNetFromCaffe(modelTxt, modelBin);
            //! [Read and initialize network]
        }
        catch (cv::Exception& e) {
            std::cerr << "Exception: " << e.what() << std::endl;
            //! [Check that network was read successfully]
            if (net.empty())
            {
                std::cerr << "Can't load network by using the following files: " << std::endl;
                std::cerr << "prototxt: " << modelTxt << std::endl;
                std::cerr << "caffemodel: " << modelBin << std::endl;
                std::cerr << "bvlc_googlenet.caffemodel can be downloaded here:" << std::endl;
                std::cerr << "http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel" << std::endl;
                exit(-1);
            }
            //! [Check that network was read successfully]
        }

        //! [Prepare blob]
        Mat img = imread(imageFile);
        if (img.empty())
        {
            std::cerr << "Can't read image from the file: " << imageFile << std::endl;
            exit(-1);
        }

        //GoogLeNet accepts only 224x224 BGR-images
        Mat inputBlob = blobFromImage(img, 1.0f, Size(224, 224),
                                      Scalar(104, 117, 123), false); //Convert Mat to batch of images
        //! [Prepare blob]

        Mat prob;
        cv::TickMeter t;
        for (int i = 0; i < 10; i++)
        {
            CV_TRACE_REGION("forward");
            //! [Set input blob]
            net.setInput(inputBlob, "data"); //set the network input
            //! [Set input blob]
            t.start();
            //! [Make forward pass]
            prob = net.forward("prob"); //compute output
            //! [Make forward pass]
            t.stop();
        }

        //! [Gather output]
        int classId;
        double classProb;
        getMaxClass(prob, &classId, &classProb);//find the best class
        //! [Gather output]

        //! [Print results]
        std::vector<String> classNames = readClassNames();
        std::cout << "Best class: #" << classId << " '" << classNames.at(classId) << "'" << std::endl;
        std::cout << "Probability: " << classProb * 100 << "%" << std::endl;
        //! [Print results]

        std::cout << "Time: " << (double)t.getTimeMilli() / t.getCounter() << " ms (average from " << t.getCounter() << " iterations)" << std::endl;

        return 0;
    } //main
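
The sample reports only the single best class through getMaxClass. If you want to look at several candidates, a top-k selection over the probability values is one possible variation. The sketch below is plain C++ and assumes the probabilities have already been copied out of the output blob into a std::vector<float>. The helper name topK is hypothetical and is not part of the sample or of OpenCV.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical helper: returns the k highest-scoring (classId, probability)
// pairs, ordered from most to least probable.
std::vector<std::pair<int, float>> topK(const std::vector<float>& probs, std::size_t k)
{
    assert(!probs.empty());
    k = std::min(k, probs.size());

    // Sort class indices instead of the values so the ids are preserved.
    std::vector<int> ids(probs.size());
    std::iota(ids.begin(), ids.end(), 0);
    std::partial_sort(ids.begin(),
                      ids.begin() + static_cast<std::ptrdiff_t>(k),
                      ids.end(),
                      [&probs](int a, int b) { return probs[a] > probs[b]; });

    std::vector<std::pair<int, float>> result;
    for (std::size_t i = 0; i < k; ++i)
        result.emplace_back(ids[i], probs[ids[i]]);
    return result;
}
```

With k set to 1 this gives the same class index that getMaxClass would report for the same values.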


 

 

Places that need to be modified:

Location 1

 

 
    Mat inputBlob = blobFromImage(img, 1.0f, Size(224, 224),
                                  Scalar(104, 117, 123), false);

Scalar(104, 117, 123) is the per-channel mean computed by make_imagenet_mean.sh. Replace it with the mean of your own training data.
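
To show where three numbers like these come from, here is a minimal sketch that averages each channel over an interleaved 8-bit pixel buffer. It is only an illustration in plain C++: the helper name channelMeans is hypothetical, and the Caffe script itself works on the whole training database rather than on one buffer. As far as I understand it, the values passed to Scalar are this kind of per-channel average, listed in the same channel order as the image data (B, G, R for images read by imread).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: average value of each channel in an interleaved
// 8-bit buffer (for example B,G,R,B,G,R,... as stored by imread).
std::vector<double> channelMeans(const std::vector<unsigned char>& pixels,
                                 std::size_t channels)
{
    assert(channels > 0 && pixels.size() % channels == 0);

    std::vector<double> sums(channels, 0.0);
    for (std::size_t i = 0; i < pixels.size(); ++i)
        sums[i % channels] += pixels[i];

    const std::size_t count = pixels.size() / channels;
    if (count == 0)
        return sums; // empty input: all means stay at zero

    for (double& s : sums)
        s /= static_cast<double>(count);
    return sums;
}
```

To cover a whole dataset you would accumulate the sums and the pixel count over every training image before dividing.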

 

—————————————————————————————————————————

Location 2

    static std::vector<String> readClassNames(const char *filename = "synset_words.txt")

The name of the label file.

Its format is:

    0 qualified
    1 excellent
    2 good
    3 unqualified
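
The label format above (an index, a space, then the class name) can be parsed as in the sketch below, which follows the same rule as readClassNames: everything after the first space is the label. The function names parseClassNames and labelFor are hypothetical. Reading from a std::istream instead of a file name is my own choice so that the parsing can be tried without a file on disk, and the handling of Windows line endings is an addition the sample does not have.

```cpp
#include <cassert>
#include <cstddef>
#include <istream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical variant of readClassNames that reads from any input stream.
std::vector<std::string> parseClassNames(std::istream& in)
{
    std::vector<std::string> names;
    std::string line;
    while (std::getline(in, line))
    {
        if (!line.empty() && line.back() == '\r')
            line.pop_back();               // tolerate CRLF line endings
        if (line.empty())
            continue;                      // skip blank lines
        const std::string::size_type space = line.find(' ');
        // Keep the text after the first space; without a space keep the whole line.
        names.push_back(space == std::string::npos ? line : line.substr(space + 1));
    }
    return names;
}

// Looks up a label by class index, like classNames.at(classId) in the sample.
const std::string& labelFor(const std::vector<std::string>& names, int classId)
{
    assert(classId >= 0 && static_cast<std::size_t>(classId) < names.size());
    return names[static_cast<std::size_t>(classId)];
}
```

The labels are matched to the network output purely by line order, so the file needs one line per output class, in the order used during training.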

—————————————————————————————————————————

Location 3

    String modelTxt = "bvlc_googlenet.prototxt";
    String modelBin = "bvlc_googlenet.caffemodel";
    String imageFile = (argc > 1) ? argv[1] : "space_shuttle.jpg";

The paths to the GoogLeNet model files and to the input image.
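
Instead of editing these three strings and recompiling for every model, one option is to read all of them from the command line. The sketch below is an assumption on my part and not something the sample does: the struct Settings, the function parseSettings, and the positional order (image, prototxt, caffemodel) are all hypothetical. The defaults are the file names used in the sample.

```cpp
#include <cassert>
#include <string>

// Hypothetical settings bundle; the defaults copy the file names in the sample.
struct Settings
{
    std::string modelTxt  = "bvlc_googlenet.prototxt";
    std::string modelBin  = "bvlc_googlenet.caffemodel";
    std::string imageFile = "space_shuttle.jpg";
};

// Assumed positional convention: program [image] [prototxt] [caffemodel].
// Any argument that is missing keeps its default value.
Settings parseSettings(int argc, const char* const* argv)
{
    assert(argc >= 0);
    Settings s;
    if (argc > 1) s.imageFile = argv[1];
    if (argc > 2) s.modelTxt  = argv[2];
    if (argc > 3) s.modelBin  = argv[3];
    return s;
}
```

Inside main this would be called as parseSettings(argc, argv), and the three fields would then be passed to readNetFromCaffe and imread in place of the hard-coded strings.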

 

 

Run result: (screenshot not preserved)
