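In the previous post we covered building the opencv_contrib modules on Windows, and mentioned that its dnn module can read trained Caffe models for object detection. This post shows in detail how to use dnn to load a Caffe model and classify an image with it.
The code is as follows (mostly taken from references [2] and [3]):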
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <fstream>
#include <iostream>
#include <cstdlib>
/* Find best class for the blob (i.e. class with maximal probability) */
void getMaxClass(cv::dnn::Blob &probBlob, int *classId, double *classProb)
{
    cv::Mat probMat = probBlob.matRefConst().reshape(1, 1); //reshape the blob to a 1x1000 matrix
    cv::Point classNumber;
    cv::minMaxLoc(probMat, NULL, classProb, NULL, &classNumber);
    *classId = classNumber.x; //column index of the maximum = class index
}

std::vector<cv::String> readClassNames(const char *filename = "synset_words.txt")
{
    std::vector<cv::String> classNames;
    std::ifstream fp(filename);
    if (!fp.is_open())
    {
        std::cerr << "File with classes labels not found: " << filename << std::endl;
        exit(-1);
    }
    std::string name;
    while (std::getline(fp, name)) //read line by line until end of file
    {
        if (name.length())
            classNames.push_back(name.substr(name.find(' ') + 1)); //drop the leading WordNet ID
    }
    fp.close();
    return classNames;
}
int main(int argc, char **argv)
{
    cv::dnn::initModule(); //initialize the dnn module and its built-in layers
    cv::String modelTxt = "bvlc_googlenet.prototxt";
    cv::String modelBin = "bvlc_googlenet.caffemodel";
    cv::String imageFile = "space_shuttle.jpg";
    cv::dnn::Net net = cv::dnn::readNetFromCaffe(modelTxt, modelBin);
    if (net.empty())
    {
        std::cerr << "Can't load network by using the following files: " << std::endl;
        std::cerr << "prototxt: " << modelTxt << std::endl;
        std::cerr << "caffemodel: " << modelBin << std::endl;
        std::cerr << "bvlc_googlenet.caffemodel can be downloaded here:" << std::endl;
        std::cerr << "http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel" << std::endl;
        exit(-1);
    }
    //! [Prepare blob]
    cv::Mat img = cv::imread(imageFile, cv::IMREAD_COLOR);
    if (img.empty())
    {
        std::cerr << "Can't read image from the file: " << imageFile << std::endl;
        exit(-1);
    }
    cv::resize(img, img, cv::Size(224, 224));
    cv::dnn::Blob inputBlob = cv::dnn::Blob(img); //Convert Mat to dnn::Blob image batch
    //! [Prepare blob]
    //! [Set input blob]
    net.setBlob(".data", inputBlob); //set the network input
    //! [Set input blob]
    //! [Make forward pass]
    net.forward(); //compute output
    //! [Make forward pass]
    //! [Gather output]
    cv::dnn::Blob prob = net.getBlob("prob"); //gather output of "prob" layer
    int classId;
    double classProb;
    getMaxClass(prob, &classId, &classProb); //find the best class
    //! [Gather output]
    //! [Print results]
    std::vector<cv::String> classNames = readClassNames();
    std::cout << "Best class: #" << classId << " '" << classNames.at(classId) << "'" << std::endl;
    std::cout << "Probability: " << classProb * 100 << "%" << std::endl;
    //! [Print results]
    return 0;
} //main
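1. First, download the GoogLeNet model and the classification files from the official sources (the text files can also simply be copied and pasted): bvlc_googlenet.prototxt, bvlc_googlenet.caffemodel, and synset_words.txt. Alternatively, download the bundle I uploaded, which also includes the image from step 2.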
2. Download the image to be classified:
[Image: Buran space shuttle]
3. Read the .prototxt and .caffemodel files:
cv::dnn::Net net = cv::dnn::readNetFromCaffe(modelTxt, modelBin);
4. Check that the network was loaded successfully:
if (net.empty())
{
    std::cerr << "Can't load network by using the following files: " << std::endl;
    std::cerr << "prototxt: " << modelTxt << std::endl;
    std::cerr << "caffemodel: " << modelBin << std::endl;
    std::cerr << "bvlc_googlenet.caffemodel can be downloaded here:" << std::endl;
    std::cerr << "http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel" << std::endl;
    exit(-1);
}
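net.empty() only tells you that loading failed, not which file was the problem. A small pre-check like the sketch below can report exactly which of the model and label files cannot be opened; missingFiles is a hypothetical helper (not part of OpenCV), and relative paths are resolved against the program's current working directory.

```cpp
#include <fstream>
#include <string>
#include <vector>

// Hypothetical helper: return every path in `paths` that cannot be opened for reading.
std::vector<std::string> missingFiles(const std::vector<std::string> &paths)
{
    std::vector<std::string> missing;
    for (const std::string &p : paths)
    {
        std::ifstream f(p, std::ios::binary);    // just try to open it
        if (!f.is_open())
            missing.push_back(p);
    }
    return missing;
}
```

For example, call missingFiles({"bvlc_googlenet.prototxt", "bvlc_googlenet.caffemodel", "synset_words.txt"}) before readNetFromCaffe and print whatever it returns.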
5. Read the image and convert it into a blob that GoogLeNet can take as input:
cv::Mat img = cv::imread(imageFile, cv::IMREAD_COLOR);
if (img.empty())
{
    std::cerr << "Can't read image from the file: " << imageFile << std::endl;
    exit(-1);
}
cv::resize(img, img, cv::Size(224, 224));
cv::dnn::Blob inputBlob = cv::dnn::Blob(img); //Convert Mat to dnn::Blob image batch
6. Pass the blob to the network as its input:
net.setBlob(".data", inputBlob); //set the network input
7. Run the forward pass:
net.forward(); //compute output
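In bvlc_googlenet.prototxt the last layer, named "prob", is a Softmax layer, so after the forward pass its 1000 outputs are probabilities that sum to 1. For intuition, here is a minimal softmax in plain C++; it illustrates the computation and is not OpenCV's implementation. Subtracting the largest score before exponentiating keeps std::exp from overflowing on large scores without changing the result.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable softmax: exp(s_i - max) / sum_j exp(s_j - max)
std::vector<double> softmax(const std::vector<double> &scores)
{
    std::vector<double> out(scores.size());
    if (scores.empty())
        return out;
    const double maxScore = *std::max_element(scores.begin(), scores.end());
    double sum = 0.0;
    for (std::size_t i = 0; i < scores.size(); ++i)
    {
        out[i] = std::exp(scores[i] - maxScore);
        sum += out[i];
    }
    for (double &p : out)
        p /= sum;                                // normalize so the outputs sum to 1
    return out;
}
```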
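8. Read the output of the "prob" layer and find the class with the highest probability:
cv::dnn::Blob prob = net.getBlob("prob"); //gather output of "prob" layer
int classId;
double classProb;
getMaxClass(prob, &classId, &classProb); //find the best class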
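getMaxClass() reshapes the blob to a 1x1000 matrix and uses cv::minMaxLoc to locate the largest value; the x coordinate of that location is the class index. The same argmax on a plain std::vector<float> can be sketched as follows; argMax is a hypothetical name, and on ties std::max_element returns the first maximum.

```cpp
#include <algorithm>
#include <iterator>
#include <stdexcept>
#include <vector>

// Plain-vector counterpart of getMaxClass(): index and value of the largest probability.
void argMax(const std::vector<float> &probs, int *classId, double *classProb)
{
    if (probs.empty())
        throw std::invalid_argument("empty probability vector");
    auto it = std::max_element(probs.begin(), probs.end());
    *classId = static_cast<int>(std::distance(probs.begin(), it));
    *classProb = *it;
}
```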
9. Print the classification result:
std::vector<cv::String> classNames = readClassNames();
std::cout << "Best class: #" << classId << " '" << classNames.at(classId) << "'" << std::endl;
std::cout << "Probability: " << classProb * 100 << "%" << std::endl;
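If you also want to see the runner-up classes (ImageNet results are often reported as top-5), the best k indices can be picked with std::partial_sort. This is a sketch that goes beyond the original program: topK is a hypothetical name, and probs is assumed to hold the 1000 values of the "prob" layer copied into a std::vector<float>.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

// Indices of the k largest probabilities, best first (k is clamped to probs.size()).
std::vector<int> topK(const std::vector<float> &probs, std::size_t k)
{
    std::vector<int> idx(probs.size());
    std::iota(idx.begin(), idx.end(), 0);        // 0, 1, 2, ...
    k = std::min(k, idx.size());
    // Only the first k positions need to end up sorted, by descending probability.
    std::partial_sort(idx.begin(), idx.begin() + static_cast<std::ptrdiff_t>(k), idx.end(),
                      [&probs](int a, int b) { return probs[a] > probs[b]; });
    idx.resize(k);
    return idx;
}
```

Each returned index i can then be printed with classNames.at(i) and probs[i] * 100, the same way as the single best class.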
Running the program at this point fails with an error.
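After a long search I finally found the solution in reference [3]: the way the image data is converted to a blob here comes from an older version of the dnn module and is not compatible with the newer one. The fix is to replace cv::dnn::Blob(img) with cv::dnn::Blob::fromImages(img):
cv::dnn::Blob inputBlob = cv::dnn::Blob::fromImages(img); //Convert Mat to dnn::Blob image batch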
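As far as we understand it, the difference comes down to memory layout: a cv::Mat returned by imread stores its pixels interleaved (rows x columns x channels, in BGR order), while dnn blobs are 4-D N x C x H x W arrays with each channel in its own plane, and Blob::fromImages performs that rearrangement along with the conversion to float. The plain C++ sketch below illustrates the rearrangement for a single image; hwcToNchw is a hypothetical name and this is not OpenCV's code.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Rearrange one interleaved 8-bit image (H x W x C, the way a CV_8UC3 cv::Mat
// stores its pixels) into a planar float buffer laid out as 1 x C x H x W.
std::vector<float> hwcToNchw(const std::vector<std::uint8_t> &hwc,
                             std::size_t h, std::size_t w, std::size_t c)
{
    if (hwc.size() != h * w * c)
        throw std::invalid_argument("buffer size does not match h * w * c");

    std::vector<float> nchw(h * w * c);
    for (std::size_t y = 0; y < h; ++y)
        for (std::size_t x = 0; x < w; ++x)
            for (std::size_t ch = 0; ch < c; ++ch)
                // source: pixel (y, x), channel ch, channels interleaved
                // dest:   plane ch, row y, column x
                nchw[ch * h * w + y * w + x] =
                    static_cast<float>(hwc[(y * w + x) * c + ch]);
    return nchw;
}
```

For a 1x2 image whose two BGR pixels are (1, 2, 3) and (4, 5, 6), the planar result is 1, 4 | 2, 5 | 3, 6: all blue values first, then green, then red.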
After this change the program runs and prints the best class together with its probability.
References:
[1] http://docs.opencv.org/trunk/d5/de7/tutorial_dnn_googlenet.html
[2] http://blog.csdn.net/langb2014/article/details/50555910
[3] https://github.com/opencv/opencv_contrib/issues/749
-----------------------------------------
2017.07.24