【caffe】OpenCV Load caffe model

上一篇,我们介绍了opencv_contrib中的模块在windows下的编译,也提到了其中的dnn模块可以读取caffe的训练模型用于目标检测,这里我们具体介绍一下如何使用dnn读取caffe模型并进行目标分类。


代码如下:(代码主要来自参考[2]和[3]):

#include <opencv2/dnn.hpp>  
#include <opencv2/imgproc.hpp>  
#include <opencv2/highgui.hpp>  

#include <fstream>  
#include <iostream>  
#include <cstdlib>  

/* Find best class for the blob (i. e. class with maximal probability) */
void getMaxClass(cv::dnn::Blob &probBlob, int *classId, double *classProb)
{
	cv::Mat probMat = probBlob.matRefConst().reshape(1, 1); //reshape the blob to 1x1000 matrix  
	cv::Point classNumber;

	cv::minMaxLoc(probMat, NULL, classProb, NULL, &classNumber);
	*classId = classNumber.x;
}

std::vector<cv::String> readClassNames(const char *filename = "synset_words.txt")
{
	std::vector<cv::String> classNames;

	std::ifstream fp(filename);
	if (!fp.is_open())
	{
		std::cerr << "File with classes labels not found: " << filename << std::endl;
		exit(-1);
	}

	std::string name;
	while (!fp.eof())
	{
		std::getline(fp, name);
		if (name.length())
			classNames.push_back(name.substr(name.find(' ') + 1));
	}

	fp.close();
	return classNames;
}

int main(int argc, char **argv)
{
	void cv::dnn::initModule();

	cv::String modelTxt = "bvlc_googlenet.prototxt";
	cv::String modelBin = "bvlc_googlenet.caffemodel";
	cv::String imageFile = "space_shuttle.jpg";

	cv::dnn::Net net = cv::dnn::readNetFromCaffe(modelTxt, modelBin);

	if (net.empty())
	{
		std::cerr << "Can't load network by using the following files: " << std::endl;
		std::cerr << "prototxt:   " << modelTxt << std::endl;
		std::cerr << "caffemodel: " << modelBin << std::endl;
		std::cerr << "bvlc_googlenet.caffemodel can be downloaded here:" << std::endl;
		std::cerr << "http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel" << std::endl;
		exit(-1);
	}

	//! [Prepare blob]  
	cv::Mat img = cv::imread(imageFile, cv::IMREAD_COLOR);
	if (img.empty())
	{
		std::cerr << "Can't read image from the file: " << imageFile << std::endl;
		exit(-1);
	}

	cv::resize(img, img, cv::Size(224, 224));        
	cv::dnn::Blob inputBlob = cv::dnn::Blob(img);   //Convert Mat to dnn::Blob image batch  
	//! [Prepare blob]  

	//! [Set input blob]  
	net.setBlob(".data", inputBlob);        //set the network input  
	//! [Set input blob]  

	//! [Make forward pass]  
	net.forward();                          //compute output  
	//! [Make forward pass]  

	//! [Gather output]  
	cv::dnn::Blob prob = net.getBlob("prob");   //gather output of "prob" layer  

	int classId;
	double classProb;
	getMaxClass(prob, &classId, &classProb);//find the best class  
	//! [Gather output]  

	//! [Print results]  
	std::vector<cv::String> classNames = readClassNames();
	std::cout << "Best class: #" << classId << " '" << classNames.at(classId) << "'" << std::endl;
	std::cout << "Probability: " << classProb * 100 << "%" << std::endl;

	//! [Print results]  

	return 0;
} //main


代码详解 :
1、首先需要下载GoogLeNet模型及分类相关文件,可以从官网下载(或复制粘贴):   bvlc_googlenet.prototxtbvlc_googlenet.caffemodel以及synset_words.txt.也可以直接下载我长传的打包好的资源(包括了2中的图片)

2、下载待检测图片文件,如下:


Buran space shuttle

3、读取.protxt文件和.caffemodel文件:

cv::dnn::Net net = cv::dnn::readNetFromCaffe(modelTxt, modelBin);

4、检查网络是否读取成功:

	if (net.empty())
	{
		std::cerr << "Can't load network by using the following files: " << std::endl;
		std::cerr << "prototxt:   " << modelTxt << std::endl;
		std::cerr << "caffemodel: " << modelBin << std::endl;
		std::cerr << "bvlc_googlenet.caffemodel can be downloaded here:" << std::endl;
		std::cerr << "http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel" << std::endl;
		exit(-1);
	}

5、读取图片并将其转换成GoogleNet可以读取的blob:

	cv::Mat img = cv::imread(imageFile, cv::IMREAD_COLOR);
	if (img.empty())
	{
		std::cerr << "Can't read image from the file: " << imageFile << std::endl;
		exit(-1);
	}

	cv::resize(img, img, cv::Size(224, 224));        
	cv::dnn::Blob inputBlob = cv::dnn::Blob(img);   //Convert Mat to dnn::Blob image batch  


6、将blob传递给网络:

	net.setBlob(".data", inputBlob);        //set the network input  

7、前向传递:

	net.forward();                          //compute output  

8、分类:

	getMaxClass(prob, &classId, &classProb);//find the best class  

9、打印分类结果:

	std::vector<cv::String> classNames = readClassNames();
	std::cout << "Best class: #" << classId << " '" << classNames.at(classId) << "'" << std::endl;
	std::cout << "Probability: " << classProb * 100 << "%" << std::endl;

运行,报错如下:



找了很久,终于在参考[3]中找到了解决方案,原因是这里将图像数据转换成blob的方法来自于老版本,在新版本中不兼容。解决方法如下:cv::dnn::Blob(img) 用cv::dnn::Blob::fromImages(img)替换掉。


修改后,再运行,结果如下:



参考:

[1] http://docs.opencv.org/trunk/d5/de7/tutorial_dnn_googlenet.html
[2] http://blog.csdn.net/langb2014/article/details/50555910
[3] https://github.com/opencv/opencv_contrib/issues/749


-----------------------------------------

2017.07.24


  • 4
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 11
    评论
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值