C++ 目标检测中的precision,recall,fscore的计算

本 demo 根据目标检测输出的目标框位置,计算检测结果的 precision(精确率)、recall(召回率)和 F-score。

demo 主要涉及以下内容:

1)读取 ground truth XML 文件中的标注框(box);

2)计算预测目标框与 ground truth box 的 IOU(demo 中 IOU 阈值为 0.5);

3)IOU 的计算原理。

demo代码如下:

#include<algorithm>
#include<iostream>
#include<limits>
#include<string>
#include<vector>

#include<QFile>
#include<QString>
#include<QXmlStreamReader>

using namespace std;

// Axis-aligned box given by its corner coordinates, as stored in an annotation's
// <bndbox> element (<xmin>, <ymin>, <xmax>, <ymax>).
// Members default to 0 so a box whose fields are missing from the XML is well defined.
typedef struct Bnd_Box_
{
    int xmin = 0;
    int ymin = 0;
    int xmax = 0;
    int ymax = 0;
}Bnd_Box;

// Axis-aligned box given by its top-left corner and its size; used for the
// detector's predicted boxes (main() passes x, y, width, height to calcIOU()).
// Members default to 0 so an unfilled prediction slot is well defined.
typedef struct Obj_Box_
{
    int x = 0;
    int y = 0;
    int width = 0;
    int height = 0;
}Obj_Box;

// Image-level information read from an annotation XML file
// (<folder>, <filename>, <path>, <source>/<database>, <size>/<width|height|depth>, <segmented>).
// Numeric members default to 0 so fields absent from the file are well defined.
typedef struct Xml_Image_Head_
{
    QString folder;
    QString filename;
    QString path;
    QString source;
    QString database;
    int image_width = 0;
    int image_height = 0;
    int image_depth = 0;
    int image_size = 0;     // never set by parseBlog(); stays 0
    int segmented = 0;
}Xml_Image_Head;

// One ground-truth <object> entry of an annotation XML file: class name, pose,
// the truncated/difficult flags and its bounding box.
// Numeric members default to 0 and the box is value-initialized, so an <object>
// that omits some fields is still well defined.
typedef struct Xml_Box_
{
    QString class_name;
    QString pose;
    int truncated = 0;
    int difficult = 0;
    Bnd_Box box{};
}Xml_Box;

void parseBlog( QXmlStreamReader &reader, Xml_Image_Head &image_head, vector< Xml_Box > &objs )
{
    bool ok;
    Xml_Box obj;
    QString strElementName;
    QString strCharacters;

    objs.clear();
    while( !reader.atEnd() )
    {
        reader.readNext();
        if (reader.isStartElement())
        {  // 开始元素
            strElementName = reader.name().toString();
            reader.readNext();
            if (reader.isCharacters())
            {
                strCharacters = reader.text().toString();
            //    qDebug() << QString("Characters : %1").arg(strCharacters);
            }
         //   QString strText = reader.text().toString();

           // qDebug() << QString::fromLocal8Bit("********** 开始元素<") << strElementName << QString::fromLocal8Bit("> ********** ") << strCharacters;
            if( QString( "folder" ) == strElementName )
            {
                image_head.folder = strCharacters;
            }
            else if( QString( "filename" ) == strElementName )
            {
                image_head.filename = strCharacters;
            }
            else if( QString( "path" ) == strElementName )
            {
                image_head.path = strCharacters;
            }
            else if( QString( "source" ) == strElementName )
            {
                image_head.source = strCharacters;
            }
            else if( QString( "database" ) == strElementName )
            {
                image_head.database = strCharacters;
            }
            else if( QString( "width" ) == strElementName )
            {
                image_head.image_width = strCharacters.toInt( &ok );
                if( false == ok )
                {
                    image_head.image_width = 0;
                }
            }
            else if( QString( "height" ) == strElementName )
            {
                image_head.image_height = strCharacters.toInt( &ok );
                if( false == ok )
                {
                    image_head.image_height = 0;
                }
            }
            else if( QString( "depth" ) == strElementName )
            {
                image_head.image_depth = strCharacters.toInt( &ok );
                if( false == ok )
                {
                    image_head.image_depth = 0;
                }
            }
            else if( QString( "segmented" ) == strElementName )
            {
                image_head.segmented = strCharacters.toInt( &ok );
                if( false == ok )
                {
                    image_head.segmented = 0;
                }
            }
            else if( QString( "object" ) == strElementName )
            {
                // TODO:: box
                bool is_do = true;
                while( is_do && !reader.atEnd() )
                {
                    reader.readNext();
                    if (reader.isStartElement())
                    {  // 开始元素
                        strElementName = reader.name().toString();
                        reader.readNext();
                        if (reader.isCharacters())
                        {
                            strCharacters = reader.text().toString();
                        }
                 //       qDebug() << QString::fromLocal8Bit("********** 开始元素<") << strElementName << QString::fromLocal8Bit("> ********** ") << strCharacters;
                        if( QString( "name" ) == strElementName )
                        {
                            obj.class_name = strCharacters;
                        }
                        else if( QString( "pose" ) == strElementName )
                        {
                            obj.pose = strCharacters;
                        }
                        else if( QString( "truncated" ) == strElementName )
                        {
                            obj.truncated = strCharacters.toInt( &ok );
                            if( false == ok )
                            {
                                obj.truncated = 0;
                            }
                        }
                        else if( QString( "difficult" ) == strElementName )
                        {
                            obj.difficult = strCharacters.toInt( &ok );
                            if( false == ok )
                            {
                                obj.difficult = 0;
                            }
                        }
                        else if( QString( "xmin" ) == strElementName )
                        {
                            obj.box.xmin = strCharacters.toInt( &ok );
                            if( false == ok )
                            {
                                obj.box.xmin = 0;
                            }
                        }
                        else if( QString( "ymin" ) == strElementName )
                        {
                            obj.box.ymin = strCharacters.toInt( &ok );
                            if( false == ok )
                            {
                                obj.box.ymin = 0;
                            }
                        }
                        else if( QString( "xmax" ) == strElementName )
                        {
                            obj.box.xmax = strCharacters.toInt( &ok );
                            if( false == ok )
                            {
                                obj.box.xmax = 0;
                            }
                        }
                        else if( QString( "ymax" ) == strElementName )
                        {
                            obj.box.ymax = strCharacters.toInt( &ok );
                            if( false == ok )
                            {
                                obj.box.ymax = 0;
                            }
                        }
                    }
                    else if ( reader.isEndElement() )
                    {  // 结束元素
                        QString strElementName = reader.name().toString();
                  //      qDebug() << QString::fromLocal8Bit("********** 结束元素<") << strElementName << QString::fromLocal8Bit("> ********** ");
                        if( 0 == strElementName.compare( "object" ) )
                        {
                            is_do = false;
                            objs.push_back( obj );
                        }
                    }
                }
            }
        }
        else if (reader.isEntityReference())
        {  // 实体引用
            QString strName = reader.name().toString();
            QString strText = reader.text().toString();
         //   qDebug() << QString("EntityReference : %1(%2)").arg(strName).arg(strText);
        }
        else if (reader.isCDATA())
        {  // CDATA
         //   QString strCDATA = reader.text().toString();
        //    qDebug() << QString("CDATA : %1").arg(strCDATA);

            reader.readNext();
            if (reader.isCharacters())
            {
                QString strCharacters = reader.text().toString();
        //        qDebug() << QString("Characters : %1").arg(strCharacters);
            }
        }
        else if (reader.isEndElement())
        {  // 结束元素
            QString strElementName = reader.name().toString();
       //     qDebug() << QString::fromLocal8Bit("********** 结束元素<") << strElementName << QString::fromLocal8Bit("> ********** ");
        }
        else if (reader.isDTD())
        {  // CDATA
            QString strDTD = reader.text().toString();
       //     qDebug() << QString("DTD : %1").arg(strDTD);

            reader.readNext();
            if (reader.isCharacters())
            {
                QString strCharacters = reader.text().toString();
        //        qDebug() << QString("Characters : %1").arg(strCharacters);
            }
        }
    }
}

void analysis_xml( QString xml_name, Xml_Image_Head &image_head, vector< Xml_Box > &box )
{

    QFile file_io( xml_name );
    file_io.open( QIODevice::ReadOnly );
    QXmlStreamReader xml_read( &file_io );

    if( file_io.isOpen() )
    {
        while( !xml_read.atEnd() )
        {
            QXmlStreamReader::TokenType nType = xml_read.readNext();
            switch( nType )
            {
                case QXmlStreamReader::StartDocument:
                {   // 开始文档
                //    qDebug() << QString::fromLocal8Bit( "********** 开始文档(XML 声明) ********** ");
                    // XML 声明
                    QString strVersion = xml_read.documentVersion().toString();
                    QString strEncoding = xml_read.documentEncoding().toString();
                    bool bAlone = xml_read.isStandaloneDocument();
                //    qDebug() << QString::fromLocal8Bit("版本:%1  编码:%2  Standalone:%3")
                 //               .arg(strVersion).arg(strEncoding).arg(bAlone) << "\r\n";
                    break;
                }
                case QXmlStreamReader::Comment:
                {   // 注释
              //      qDebug() << QString::fromLocal8Bit("********** 注释 ********** ");
                    QString strComment = xml_read.text().toString();
                //    qDebug() << strComment << "\r\n";
                    break;
                }
                case QXmlStreamReader::ProcessingInstruction:
                {   // 处理指令
               //     qDebug() << QString::fromLocal8Bit("********** 处理指令 ********** ");
                    QString strProcInstr = xml_read.processingInstructionData().toString();
               //     qDebug() << strProcInstr << "\r\n";
                    break;
                }
                case QXmlStreamReader::DTD:
                {   // DTD
                //    qDebug() << QString::fromLocal8Bit("********** DTD ********** ");
                    QString strDTD = xml_read.text().toString();
                //    QXmlStreamNotationDeclarations notationDeclarations = xml_read.notationDeclarations();  // 符号声明
                //    QXmlStreamEntityDeclarations entityDeclarations = xml_read.entityDeclarations();  // 实体声明
                    // DTD 声明
                    QString strDTDName = xml_read.dtdName().toString();
                    QString strDTDPublicId = xml_read.dtdPublicId().toString();  // DTD 公开标识符
                    QString strDTDSystemId = xml_read.dtdSystemId().toString();  // DTD 系统标识符
               //     qDebug() << QString::fromLocal8Bit("DTD : %1").arg(strDTD);
               //     qDebug() << QString::fromLocal8Bit("DTD 名称 : %1").arg(strDTDName);
               //     qDebug() << QString::fromLocal8Bit("DTD 公开标识符 : %1").arg(strDTDPublicId);
               //     qDebug() << QString::fromLocal8Bit("DTD 系统标识符 : %1").arg(strDTDSystemId);
               //     qDebug() << "\r\n";

                    break;
                }
                case QXmlStreamReader::StartElement:
                {   // 开始元素
                    QString strElementName = xml_read.name().toString();
                    if (QString::compare( strElementName, "annotation") == 0)
                    {   // 根元素
                   //     qDebug() << QString::fromLocal8Bit( "********** 开始元素<annotation> ********** " );
                        QXmlStreamAttributes attributes = xml_read.attributes();
                        if (attributes.hasAttribute("object"))
                        {
                            QString strVersion = attributes.value("Version").toString();
                //            qDebug() << QString::fromLocal8Bit("属性:Version(%1)").arg(strVersion);
                        }

                        parseBlog(xml_read, image_head, box );
                    }
                    break;
                }
                case QXmlStreamReader::EndDocument:
                {   // 结束文档
                //    qDebug() << QString::fromLocal8Bit("********** 结束文档 ********** ");
                    break;
                }

                default:
                    break;
            }
        }
    }
}
// One detector prediction: its box plus a probability and a class id.
// obj_prob / obj_class are presumably the detection score and class index
// (not read by the evaluation in main) - confirm against the detector.
// Members are value-initialized so an unfilled prediction slot is all zeros.
typedef struct Obj_Info_
{
    Obj_Box obj_box{};
    float obj_prob = 0.0f;
    int obj_class = 0;
}Obj_Info;

/**
 * Intersection-over-Union of two axis-aligned boxes, each given as
 * (xmin, ymin, width, height).
 *
 * The intersection is computed directly as the overlap of the two coordinate ranges.
 * Boxes that merely touch have zero overlap. Boxes with a non-positive width or height
 * have no area; for them the function returns 0 instead of dividing 0 by 0.
 * Arithmetic is done in 64 bits so xmin + width and width * height cannot overflow.
 *
 * For overlapping boxes the value is also printed to stdout (debug output of the demo).
 *
 * @return IoU in [0, 1]
 */
float calcIOU(int prediction_xmin,int prediction_ymin,int prediction_width,int prediction_height,
            int GT_xmin,        int GT_ymin,        int GT_width,        int GT_height)
{
    if (prediction_width <= 0 || prediction_height <= 0 || GT_width <= 0 || GT_height <= 0)
    {
        return 0.0f;
    }

    const long long pred_x1 = prediction_xmin;
    const long long pred_y1 = prediction_ymin;
    const long long pred_x2 = pred_x1 + prediction_width;
    const long long pred_y2 = pred_y1 + prediction_height;

    const long long gt_x1 = GT_xmin;
    const long long gt_y1 = GT_ymin;
    const long long gt_x2 = gt_x1 + GT_width;
    const long long gt_y2 = gt_y1 + GT_height;

    // Overlap of the x and y ranges; non-positive means the boxes are disjoint or only touch.
    const long long inter_w = std::min(pred_x2, gt_x2) - std::max(pred_x1, gt_x1);
    const long long inter_h = std::min(pred_y2, gt_y2) - std::max(pred_y1, gt_y1);
    if (inter_w <= 0 || inter_h <= 0)
    {
        return 0.0f;
    }

    const long long inter_area = inter_w * inter_h;
    const long long pred_area = static_cast<long long>(prediction_width) * prediction_height;
    const long long gt_area = static_cast<long long>(GT_width) * GT_height;
    // Strictly positive: both areas are positive and inter_area <= min(pred_area, gt_area).
    const long long union_area = pred_area + gt_area - inter_area;

    const float IOU = static_cast<float>(static_cast<double>(inter_area) /
                                         static_cast<double>(union_area));
    std::cout << "calcIOU:" << IOU << '\n';
    return IOU;
}
int main()
{
    Xml_Image_Head image_head;
    vector< Xml_Box > boxes;
    string img_file = "child_hug_vid1_cut_person_child_1_2_100000001.jpg";
    string root_path_xml ="/home/fuxueping/sdb/Caffe_Project/test_image_all_xml/" ;
    int i = img_file.find(".jpg");

    string xml_name(img_file.substr(0,i-1));

    string xml_file = root_path_xml+xml_name + ".xml";
    QString xml_dir = QString::fromStdString(xml_file);
    analysis_xml( xml_dir, image_head, boxes );
    int total = 0;
    int correct = 0;
    int proposal = 0;
    int out_obj_num;
    proposal += out_obj_num;
    total += boxes.size();
    Obj_Info obj_info[10];

    int num_box = 0;
    float iou_thresh = 0.5;
    float eps = 0.000000001;
    for(int i = 0; i < boxes.size();i++)
    {
        float best_iou = 0;
        for (int k=0; k < out_obj_num; k++ )
        {
            float w  = obj_info[ k ].obj_box.width;
            float h  = obj_info[ k ].obj_box.height;
            float xmin = obj_info[ k ].obj_box.x;
            float ymin = obj_info[ k ].obj_box.y;
            float iou = calcIOU(xmin,ymin,w,h,
                              boxes[i].box.xmin,
                              boxes[i].box.ymin,
                              boxes[i].box.xmax-boxes[i].box.xmin,
                              boxes[i].box.ymax - boxes[i].box.ymin);
            if(iou > best_iou)
            {
                best_iou = iou;
            }
        }
        if((best_iou > iou_thresh)&&(num_box < out_obj_num))
        {
            correct += 1;
            num_box += 1;
        }
    }
    float precision = 1.0 × correct/(proposal + eps);
    float recall = 1.0 * correct/(total + eps);
    float fscore = 2.0 * precision * recall /(precision + recall + eps);
    return 0;
}

XML 标注文件示例如下:

<annotation>
  <folder>child_hug_vid1_cut</folder>
  <filename>2017-06-30_13-56-10_00686</filename>
  <path>/home/csh/Desktop/label/py/child_hug_vid1_cut/2017-06-30_13-56-10_00686.jpg</path>
  <source>
    <database>Unknown</database>
  </source>
  <size>
    <width>433</width>
    <height>229</height>
    <depth>3</depth>
  </size>
  <segmented>0</segmented>
  <object>
    <name>1</name>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
      <xmin>165</xmin>
      <ymin>46</ymin>
      <xmax>251</xmax>
      <ymax>193</ymax>
    </bndbox>
  </object>
  <object>
    <name>1</name>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
      <xmin>161</xmin>
      <ymin>22</ymin>
      <xmax>332</xmax>
      <ymax>202</ymax>
    </bndbox>
  </object>
</annotation>

以上是一个 demo 程序。

实际使用时,程序根据图片文件名拼接出对应 XML 文件的绝对路径,因此务必保证每张测试图片都有同名的 XML 标注文件。另外,demo 中的检测结果数组 obj_info 及其个数 out_obj_num 并没有由检测器填充,使用前需替换为检测器的实际输出。

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

猫猫与橙子

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值