【caffe】OpenCV Load caffe model(转)

原文地址:http://blog.csdn.net/guduruyu/article/details/76006003

 

上一篇,我们介绍了opencv_contrib中的模块在windows下的编译,也提到了其中的dnn模块可以读取caffe的训练模型用于目标检测,这里我们具体介绍一下如何使用dnn读取caffe模型并进行目标分类。

 

代码如下:(代码主要来自参考[2]和[3]):

 

 
  1. #include <opencv2/dnn.hpp>

  2. #include <opencv2/imgproc.hpp>

  3. #include <opencv2/highgui.hpp>

  4.  
  5. #include <fstream>

  6. #include <iostream>

  7. #include <cstdlib>

  8.  
  9. /* Find best class for the blob (i. e. class with maximal probability) */

  10. void getMaxClass(cv::dnn::Blob &probBlob, int *classId, double *classProb)

  11. {

  12. cv::Mat probMat = probBlob.matRefConst().reshape(1, 1); //reshape the blob to 1x1000 matrix

  13. cv::Point classNumber;

  14.  
  15. cv::minMaxLoc(probMat, NULL, classProb, NULL, &classNumber);

  16. *classId = classNumber.x;

  17. }

  18.  
  19. std::vector<cv::String> readClassNames(const char *filename = "synset_words.txt")

  20. {

  21. std::vector<cv::String> classNames;

  22.  
  23. std::ifstream fp(filename);

  24. if (!fp.is_open())

  25. {

  26. std::cerr << "File with classes labels not found: " << filename << std::endl;

  27. exit(-1);

  28. }

  29.  
  30. std::string name;

  31. while (!fp.eof())

  32. {

  33. std::getline(fp, name);

  34. if (name.length())

  35. classNames.push_back(name.substr(name.find(' ') + 1));

  36. }

  37.  
  38. fp.close();

  39. return classNames;

  40. }

  41.  
  42. int main(int argc, char **argv)

  43. {

  44. void cv::dnn::initModule();

  45.  
  46. cv::String modelTxt = "bvlc_googlenet.prototxt";

  47. cv::String modelBin = "bvlc_googlenet.caffemodel";

  48. cv::String imageFile = "space_shuttle.jpg";

  49.  
  50. cv::dnn::Net net = cv::dnn::readNetFromCaffe(modelTxt, modelBin);

  51.  
  52. if (net.empty())

  53. {

  54. std::cerr << "Can't load network by using the following files: " << std::endl;

  55. std::cerr << "prototxt: " << modelTxt << std::endl;

  56. std::cerr << "caffemodel: " << modelBin << std::endl;

  57. std::cerr << "bvlc_googlenet.caffemodel can be downloaded here:" << std::endl;

  58. std::cerr << "http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel" << std::endl;

  59. exit(-1);

  60. }

  61.  
  62. //! [Prepare blob]

  63. cv::Mat img = cv::imread(imageFile, cv::IMREAD_COLOR);

  64. if (img.empty())

  65. {

  66. std::cerr << "Can't read image from the file: " << imageFile << std::endl;

  67. exit(-1);

  68. }

  69.  
  70. cv::resize(img, img, cv::Size(224, 224));

  71. cv::dnn::Blob inputBlob = cv::dnn::Blob(img); //Convert Mat to dnn::Blob image batch

  72. //! [Prepare blob]

  73.  
  74. //! [Set input blob]

  75. net.setBlob(".data", inputBlob); //set the network input

  76. //! [Set input blob]

  77.  
  78. //! [Make forward pass]

  79. net.forward(); //compute output

  80. //! [Make forward pass]

  81.  
  82. //! [Gather output]

  83. cv::dnn::Blob prob = net.getBlob("prob"); //gather output of "prob" layer

  84.  
  85. int classId;

  86. double classProb;

  87. getMaxClass(prob, &classId, &classProb);//find the best class

  88. //! [Gather output]

  89.  
  90. //! [Print results]

  91. std::vector<cv::String> classNames = readClassNames();

  92. std::cout << "Best class: #" << classId << " '" << classNames.at(classId) << "'" << std::endl;

  93. std::cout << "Probability: " << classProb * 100 << "%" << std::endl;

  94.  
  95. //! [Print results]

  96.  
  97. return 0;

  98. } //main

 

 

代码详解:
1、首先需要下载GoogLeNet模型及分类相关文件,可以从官网下载(或复制粘贴):  bvlc_googlenet.prototxtbvlc_googlenet.caffemodel以及synset_words.txt.也可以直接下载我长传的打包好的资源(包括了2中的图片)

2、下载待检测图片文件,如下:

Buran space shuttle

3、读取.protxt文件和.caffemodel文件:

 

cv::dnn::Net net = cv::dnn::readNetFromCaffe(modelTxt, modelBin);

 

 

 

4、检查网络是否读取成功:

 

 
  1. if (net.empty())

  2. {

  3. std::cerr << "Can't load network by using the following files: " << std::endl;

  4. std::cerr << "prototxt: " << modelTxt << std::endl;

  5. std::cerr << "caffemodel: " << modelBin << std::endl;

  6. std::cerr << "bvlc_googlenet.caffemodel can be downloaded here:" << std::endl;

  7. std::cerr << "http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel" << std::endl;

  8. exit(-1);

  9. }


5、读取图片并将其转换成GoogleNet可以读取的blob:

 

 

 
  1. cv::Mat img = cv::imread(imageFile, cv::IMREAD_COLOR);

  2. if (img.empty())

  3. {

  4. std::cerr << "Can't read image from the file: " << imageFile << std::endl;

  5. exit(-1);

  6. }

  7.  
  8. cv::resize(img, img, cv::Size(224, 224));

  9. cv::dnn::Blob inputBlob = cv::dnn::Blob(img); //Convert Mat to dnn::Blob image batch

 

 

 

 

 

6、将blob传递给网络:

 

	net.setBlob(".data", inputBlob);        //set the network input  

 

 

 

7、前向传递:

 

	net.forward();                          //compute output  

 

 

 

8、分类:

 

	getMaxClass(prob, &classId, &classProb);//find the best class  

 

 

 

9、打印分类结果:

 

 
  1. std::vector<cv::String> classNames = readClassNames();

  2. std::cout << "Best class: #" << classId << " '" << classNames.at(classId) << "'" << std::endl;

  3. std::cout << "Probability: " << classProb * 100 << "%" << std::endl;

 

 

 

运行,报错如下:

 

找了很久,终于在参考[3]中找到了解决方案,原因是这里将图像数据转换成blob的方法来自于老版本,在新版本中不兼容。解决方法如下:将cv::dnn::Blob(img) 用cv::dnn::Blob::fromImages(img)替换掉。

 

修改后,再运行,结果如下:

 

参考:

[1] http://docs.opencv.org/trunk/d5/de7/tutorial_dnn_googlenet.html
[2] http://blog.csdn.net/langb2014/article/details/50555910
[3] https://github.com/opencv/opencv_contrib/issues/749

 

-----------------------------------------

2017.07.24

 

 

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值