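本demo根据目标检测输出的目标框计算precision、recall和F-score:
precision = 正确检测数 / 检测框总数,recall = 正确检测数 / ground truth框总数,F-score = 2 × precision × recall / (precision + recall)。
demo主要涉及以下内容:
1) 从ground truth XML文件中读取标注框;
2) 计算预测框与ground truth框的IOU(demo中IOU阈值为0.5,IOU大于该阈值的预测框视为正确检测);
3) IOU的计算原理:IOU = 交集面积 / 并集面积 = |A∩B| / (|A| + |B| − |A∩B|),两框不相交时IOU为0。
demo代码如下: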
#include<algorithm>
#include<iostream>
#include<string>
#include<vector>
#include<QFile>
#include<QString>
#include<QXmlStreamReader>
using namespace std;
// Axis-aligned box in corner form, as stored in the <bndbox> element of the annotation XML.
// All coordinates default to 0 so a box is well-defined even when the XML omits a value
// (previously the members were left uninitialised).
typedef struct Bnd_Box_
{
    int xmin = 0;
    int ymin = 0;
    int xmax = 0;
    int ymax = 0;
}Bnd_Box;
// Axis-aligned box in (x, y, width, height) form, used for detector output.
// main() passes x / y to calcIOU() as the box's xmin / ymin.
// All members default to 0 (previously they were left uninitialised).
typedef struct Obj_Box_
{
    int x = 0;
    int y = 0;
    int width = 0;
    int height = 0;
}Obj_Box;
// Image-level fields of one annotation XML file (everything outside the <object> entries).
// Integer fields default to 0 so they are well-defined when the XML omits them
// (previously they were left uninitialised).
typedef struct Xml_Image_Head_
{
    QString folder;         // <folder>
    QString filename;       // <filename>
    QString path;           // <path>
    QString source;         // <source>
    QString database;       // <database>
    int image_width = 0;    // <width>
    int image_height = 0;   // <height>
    int image_depth = 0;    // <depth>
    int image_size = 0;     // NOTE(review): parseBlog() never fills this field
    int segmented = 0;      // <segmented>
}Xml_Image_Head;
// One <object> entry of an annotation XML file: class label, metadata and bounding box.
// Integer fields and the box default to 0 so an object missing some elements is still
// well-defined (previously they were left uninitialised).
typedef struct Xml_Box_
{
    QString class_name;         // <name>
    QString pose;               // <pose>
    int truncated = 0;          // <truncated>
    int difficult = 0;          // <difficult>
    Bnd_Box box = Bnd_Box();    // <bndbox>: <xmin>, <ymin>, <xmax>, <ymax>
}Xml_Box;
void parseBlog( QXmlStreamReader &reader, Xml_Image_Head &image_head, vector< Xml_Box > &objs )
{
bool ok;
Xml_Box obj;
QString strElementName;
QString strCharacters;
objs.clear();
while( !reader.atEnd() )
{
reader.readNext();
if (reader.isStartElement())
{ // 开始元素
strElementName = reader.name().toString();
reader.readNext();
if (reader.isCharacters())
{
strCharacters = reader.text().toString();
// qDebug() << QString("Characters : %1").arg(strCharacters);
}
// QString strText = reader.text().toString();
// qDebug() << QString::fromLocal8Bit("********** 开始元素<") << strElementName << QString::fromLocal8Bit("> ********** ") << strCharacters;
if( QString( "folder" ) == strElementName )
{
image_head.folder = strCharacters;
}
else if( QString( "filename" ) == strElementName )
{
image_head.filename = strCharacters;
}
else if( QString( "path" ) == strElementName )
{
image_head.path = strCharacters;
}
else if( QString( "source" ) == strElementName )
{
image_head.source = strCharacters;
}
else if( QString( "database" ) == strElementName )
{
image_head.database = strCharacters;
}
else if( QString( "width" ) == strElementName )
{
image_head.image_width = strCharacters.toInt( &ok );
if( false == ok )
{
image_head.image_width = 0;
}
}
else if( QString( "height" ) == strElementName )
{
image_head.image_height = strCharacters.toInt( &ok );
if( false == ok )
{
image_head.image_height = 0;
}
}
else if( QString( "depth" ) == strElementName )
{
image_head.image_depth = strCharacters.toInt( &ok );
if( false == ok )
{
image_head.image_depth = 0;
}
}
else if( QString( "segmented" ) == strElementName )
{
image_head.segmented = strCharacters.toInt( &ok );
if( false == ok )
{
image_head.segmented = 0;
}
}
else if( QString( "object" ) == strElementName )
{
// TODO:: box
bool is_do = true;
while( is_do && !reader.atEnd() )
{
reader.readNext();
if (reader.isStartElement())
{ // 开始元素
strElementName = reader.name().toString();
reader.readNext();
if (reader.isCharacters())
{
strCharacters = reader.text().toString();
}
// qDebug() << QString::fromLocal8Bit("********** 开始元素<") << strElementName << QString::fromLocal8Bit("> ********** ") << strCharacters;
if( QString( "name" ) == strElementName )
{
obj.class_name = strCharacters;
}
else if( QString( "pose" ) == strElementName )
{
obj.pose = strCharacters;
}
else if( QString( "truncated" ) == strElementName )
{
obj.truncated = strCharacters.toInt( &ok );
if( false == ok )
{
obj.truncated = 0;
}
}
else if( QString( "difficult" ) == strElementName )
{
obj.difficult = strCharacters.toInt( &ok );
if( false == ok )
{
obj.difficult = 0;
}
}
else if( QString( "xmin" ) == strElementName )
{
obj.box.xmin = strCharacters.toInt( &ok );
if( false == ok )
{
obj.box.xmin = 0;
}
}
else if( QString( "ymin" ) == strElementName )
{
obj.box.ymin = strCharacters.toInt( &ok );
if( false == ok )
{
obj.box.ymin = 0;
}
}
else if( QString( "xmax" ) == strElementName )
{
obj.box.xmax = strCharacters.toInt( &ok );
if( false == ok )
{
obj.box.xmax = 0;
}
}
else if( QString( "ymax" ) == strElementName )
{
obj.box.ymax = strCharacters.toInt( &ok );
if( false == ok )
{
obj.box.ymax = 0;
}
}
}
else if ( reader.isEndElement() )
{ // 结束元素
QString strElementName = reader.name().toString();
// qDebug() << QString::fromLocal8Bit("********** 结束元素<") << strElementName << QString::fromLocal8Bit("> ********** ");
if( 0 == strElementName.compare( "object" ) )
{
is_do = false;
objs.push_back( obj );
}
}
}
}
}
else if (reader.isEntityReference())
{ // 实体引用
QString strName = reader.name().toString();
QString strText = reader.text().toString();
// qDebug() << QString("EntityReference : %1(%2)").arg(strName).arg(strText);
}
else if (reader.isCDATA())
{ // CDATA
// QString strCDATA = reader.text().toString();
// qDebug() << QString("CDATA : %1").arg(strCDATA);
reader.readNext();
if (reader.isCharacters())
{
QString strCharacters = reader.text().toString();
// qDebug() << QString("Characters : %1").arg(strCharacters);
}
}
else if (reader.isEndElement())
{ // 结束元素
QString strElementName = reader.name().toString();
// qDebug() << QString::fromLocal8Bit("********** 结束元素<") << strElementName << QString::fromLocal8Bit("> ********** ");
}
else if (reader.isDTD())
{ // CDATA
QString strDTD = reader.text().toString();
// qDebug() << QString("DTD : %1").arg(strDTD);
reader.readNext();
if (reader.isCharacters())
{
QString strCharacters = reader.text().toString();
// qDebug() << QString("Characters : %1").arg(strCharacters);
}
}
}
}
/**
 * Loads one annotation XML file and extracts its image header and object boxes.
 *
 * @param xml_name   path of the XML file.
 * @param image_head reset, then filled by parseBlog() from the <annotation> root element.
 * @param box        cleared, then filled by parseBlog() with one entry per <object>.
 *
 * Both outputs are reset first, so a missing or unreadable file can no longer leave stale
 * data from a previous call behind. Open failures and XML errors are reported on stderr;
 * whatever was parsed before an XML error is kept.
 */
void analysis_xml( QString xml_name, Xml_Image_Head &image_head, vector< Xml_Box > &box )
{
    box.clear();
    image_head = Xml_Image_Head();

    QFile file_io( xml_name );
    if( !file_io.open( QIODevice::ReadOnly ) )
    {
        cerr << "analysis_xml: cannot open " << xml_name.toStdString()
             << ": " << file_io.errorString().toStdString() << '\n';
        return;
    }

    QXmlStreamReader xml_read( &file_io );
    while( !xml_read.atEnd() )
    {
        xml_read.readNext();
        // Only the <annotation> root carries data; declarations, comments and DTDs are ignored.
        if( xml_read.isStartElement()
            && xml_read.name().toString() == QLatin1String( "annotation" ) )
        {
            parseBlog( xml_read, image_head, box );
        }
    }

    if( xml_read.hasError() )
    {
        cerr << "analysis_xml: XML error in " << xml_name.toStdString()
             << " at line " << static_cast< long long >( xml_read.lineNumber() )
             << ": " << xml_read.errorString().toStdString() << '\n';
    }
}
// One detector output entry: box plus score and class. All members default to 0
// (previously they were left uninitialised, e.g. in main()'s detection buffer).
typedef struct Obj_Info_
{
    Obj_Box obj_box = Obj_Box();    // predicted box; main() uses x / y as xmin / ymin
    float obj_prob = 0.0f;          // presumably the detector confidence score - confirm
    int obj_class = 0;              // presumably the predicted class id - confirm
}Obj_Info;
/**
 * Intersection-over-Union of two axis-aligned boxes given as (xmin, ymin, width, height).
 *
 * IoU = intersection area / (area A + area B - intersection area), a value in [0, 1].
 * Boxes are treated as the half-open ranges [xmin, xmin + width) x [ymin, ymin + height),
 * so boxes that only touch have IoU 0. A box with non-positive width or height has no area
 * and yields 0.
 *
 * Compared with the previous implementation, this version fixes three problems:
 * - The intersection is clamped directly, instead of comparing int-truncated centres and
 *   taking abs(), which could report overlap for disjoint boxes with negative coordinates.
 * - Areas are computed in 64-bit, so large boxes cannot overflow int.
 * - The function no longer prints to stdout.
 */
float calcIOU(int prediction_xmin, int prediction_ymin, int prediction_width, int prediction_height,
              int GT_xmin, int GT_ymin, int GT_width, int GT_height)
{
    if (prediction_width <= 0 || prediction_height <= 0 || GT_width <= 0 || GT_height <= 0)
        return 0.0f;

    // 64-bit arithmetic so right/bottom edges and areas cannot overflow int.
    const long long inter_left = std::max<long long>(prediction_xmin, GT_xmin);
    const long long inter_top = std::max<long long>(prediction_ymin, GT_ymin);
    const long long inter_right = std::min<long long>(
        static_cast<long long>(prediction_xmin) + prediction_width,
        static_cast<long long>(GT_xmin) + GT_width);
    const long long inter_bottom = std::min<long long>(
        static_cast<long long>(prediction_ymin) + prediction_height,
        static_cast<long long>(GT_ymin) + GT_height);

    const long long inter_w = inter_right - inter_left;
    const long long inter_h = inter_bottom - inter_top;
    if (inter_w <= 0 || inter_h <= 0)
        return 0.0f;    // disjoint or only touching

    const long long inter_area = inter_w * inter_h;
    const long long union_area = static_cast<long long>(prediction_width) * prediction_height
                               + static_cast<long long>(GT_width) * GT_height
                               - inter_area;
    return static_cast<float>(static_cast<double>(inter_area) / static_cast<double>(union_area));
}
int main()
{
Xml_Image_Head image_head;
vector< Xml_Box > boxes;
string img_file = "child_hug_vid1_cut_person_child_1_2_100000001.jpg";
string root_path_xml ="/home/fuxueping/sdb/Caffe_Project/test_image_all_xml/" ;
int i = img_file.find(".jpg");
string xml_name(img_file.substr(0,i-1));
string xml_file = root_path_xml+xml_name + ".xml";
QString xml_dir = QString::fromStdString(xml_file);
analysis_xml( xml_dir, image_head, boxes );
int total = 0;
int correct = 0;
int proposal = 0;
int out_obj_num;
proposal += out_obj_num;
total += boxes.size();
Obj_Info obj_info[10];
int num_box = 0;
float iou_thresh = 0.5;
float eps = 0.000000001;
for(int i = 0; i < boxes.size();i++)
{
float best_iou = 0;
for (int k=0; k < out_obj_num; k++ )
{
float w = obj_info[ k ].obj_box.width;
float h = obj_info[ k ].obj_box.height;
float xmin = obj_info[ k ].obj_box.x;
float ymin = obj_info[ k ].obj_box.y;
float iou = calcIOU(xmin,ymin,w,h,
boxes[i].box.xmin,
boxes[i].box.ymin,
boxes[i].box.xmax-boxes[i].box.xmin,
boxes[i].box.ymax - boxes[i].box.ymin);
if(iou > best_iou)
{
best_iou = iou;
}
}
if((best_iou > iou_thresh)&&(num_box < out_obj_num))
{
correct += 1;
num_box += 1;
}
}
float precision = 1.0 × correct/(proposal + eps);
float recall = 1.0 * correct/(total + eps);
float fscore = 2.0 * precision * recall /(precision + recall + eps);
return 0;
}
XML标注文件的格式示例如下:
<annotation>
<folder>child_hug_vid1_cut</folder>
<filename>2017-06-30_13-56-10_00686</filename>
<path>/home/csh/Desktop/label/py/child_hug_vid1_cut/2017-06-30_13-56-10_00686.jpg</path>
<source>
<database>Unknown</database>
</source>
<size>
<width>433</width>
<height>229</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
<object>
<name>1</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>165</xmin>
<ymin>46</ymin>
<xmax>251</xmax>
<ymax>193</ymax>
</bndbox>
</object>
<object>
<name>1</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>161</xmin>
<ymin>22</ymin>
<xmax>332</xmax>
<ymax>202</ymax>
</bndbox>
</object>
</annotation>
以上是一个demo程序。
实际使用时,程序根据图片文件名拼接出对应XML文件的绝对路径,因此务必保证每张测试图片都有同名的XML标注文件。