TLD(Tracking-Learning-Detection)学习与源码理解

zouxy09@qq.com

TLD(Tracking-Learning-Detection)是英国萨里大学的一个捷克籍博士生Zdenek Kalal在其攻读博士学位期间提出的一种新的单目标长时间(long term tracking)跟踪算法。该算法与传统跟踪算法的显著区别在于将传统的跟踪算法和传统的检测算法相结合来解决被跟踪目标在被跟踪过程中发生的形变、部分遮挡等问题。同时,通过一种改进的在线学习机制不断更新跟踪模块的“显著特征点”和检测模块的目标模型及相关参数,从而使得跟踪效果更加稳定、鲁棒、可靠。

作者网站的链接http://info.ee.surrey.ac.uk/Personal/Z.Kalal/

其开放源代码,在网站上可以下载到源代码已经其demo程序,但是源代码是由MatlabC写的,对于我这种不懂Matlab的菜鸟来说,看代码就像天书;但很庆幸,有一个大牛已经用c++TLD重新写好了,而且代码很规范。并且提供源码下载:

https://github.com/arthurv/OpenTLD

源码为Linux版本,基于Opencv2.3 在源码/doc文件夹下有其程序设计接口,很清晰。

ZK关于这个TLD框架发表了很多论文,感觉对理解代码非常有用的论文有下面三个:

1Tracking-Learning-Detection

2Forward-Backward Error Automatic Detection of Tracking Failures

3Online learning of robust object detectors during unstable tracking

在作者的网站上好像也提供下载了。

另外自己学习的过程中,也搜到了不少大牛对TLD的分析,得到了很多帮助,具体有:

1)《庖丁解牛TLD》系列:

http://blog.csdn.net/yang_xian521/article/details/7091587

2)《再谈PN学习》:

http://blog.csdn.net/carson2005/article/details/7647519

3)《比微软kinect更强的视频跟踪算法--TLD跟踪算法介绍》

http://blog.csdn.net/carson2005/article/details/7647500

4)《TLD视觉跟踪技术解析》

http://www.asmag.com.cn/number/n-50168.shtml

 

自己在看论文和这些大牛的分析过程中,对代码进行了一些理解,但是由于自己接触图像处理和机器视觉没多久,另外由于自己编程能力比较弱,所以分析过程可能会有不少的错误,希望各位不吝指正。具体代码分析见博客的更新。

OpenTLD下载与编译:

1https://github.com/arthurv/OpenTLD
下载得到:arthurv-OpenTLD-1e3cd0b.zip

或者在Linux下直接通过git工具进行克隆:

#git clone git@github.com:alantrrs/OpenTLD.git

2)我的编译环境是Ubuntu 12.04 + Opencv2.3

安装opencv 2.3

#apt-get install libcv-dev libcv2.3 libcvaux-dev libcvaux2.3 libhighgui-dev libhighgui2.3

安装cmake

#sudo apt-get install cmake

解压然后按照源码目录下README文件进行编译:

#cd OpenTLD

#mkdir build

#cd build

#cmake ../src/

#make

#cd ../bin/

3)运行:

%To run from camera

./run_tld -p ../parameters.yml

%To run from file

./run_tld -p ../parameters.yml -s ../datasets/06_car/car.mpg

%To init bounding box from file

./run_tld -p ../parameters.yml -s ../datasets/06_car/car.mpg -b ../datasets/06_car/init.txt

%To train only in the firs frame (no tracking, no learning)

./run_tld -p ../parameters.yml -s ../datasets/06_car/car.mpg -b ../datasets/06_car/init.txt -no_tl

%To test the final detector (Repeat the video, first time learns, second time detects)

./run_tld -p ../parameters.yml -s ../datasets/06_car/car.mpg -b ../datasets/06_car/init.txt –r

下面是自己在看论文和这些大牛的分析过程中,对代码进行了一些理解,但是由于自己接触图像处理和机器视觉没多久,另外由于自己编程能力比较弱,所以分析过程可能会有不少的错误,希望各位不吝指正。而且,因为编程很多地方不懂,所以注释得非常乱,还海涵。

 

从main()函数切入,分析整个TLD运行过程如下:

(这里只是分析工作过程,全部注释的代码见博客的更新)

1、分析程序运行的命令行参数;

./run_tld -p ../parameters.yml -s ../datasets/06_car/car.mpg -b ../datasets/06_car/init.txt –r

 

2、读入初始化参数(程序中变量)的文件parameters.yml;

 

3、通过文件或者用户鼠标框选的方式指定要跟踪的目标的Bounding Box;

 

4、用上面得到的包含要跟踪目标的Bounding  Box和第一帧图像去初始化TLD系统,

   tld.init(last_gray, box, bb_file); 初始化包含的工作如下:

 

4.1、buildGrid(frame1, box);

检测器采用扫描窗口的策略:扫描窗口步长为宽高的 10%,尺度缩放系数为1.2;此函数构建全部的扫描窗口grid,并计算每一个扫描窗口与输入的目标box的重叠度;重叠度定义为两个box的交集与它们的并集的比;

 

4.2、为各种变量或者容器分配内存空间;

 

4.3、getOverlappingBoxes(box, num_closest_init);

此函数根据传入的box(目标边界框),在整帧图像中的全部扫描窗口中(由上面4.1得到)寻找与该box距离最小(即最相似,重叠度最大)的num_closest_init(10)个窗口,然后把这些窗口归入good_boxes容器。同时,把重叠度小于0.2的,归入bad_boxes容器;相当于对全部的扫描窗口进行筛选。并通过BBhull函数得到这些扫描窗口的最大边界。

   

4.5、classifier.prepare(scales);

准备分类器,scales容器里是所有扫描窗口的尺度,由上面的buildGrid()函数初始化;

TLD的分类器有三部分:方差分类器模块、集合分类器模块和最近邻分类器模块;这三个分类器是级联的,每一个扫描窗口依次全部通过上面三个分类器,才被认为含有前景目标。这里prepare这个函数主要是初始化集合分类器模块;

集合分类器(随机森林)基于n个基本分类器(共10棵树),每个分类器(树)都是基于一个pixel comparisons(共13个像素比较集)的,也就是说每棵树有13个判断节点(组成一个pixel comparisons),输入的图像片与每一个判断节点(相应像素点)进行比较,产生0或者1,然后将这13个0或者1连成一个13位的二进制码x(有2^13种可能),每一个x对应一个后验概率P(y|x)= #p/(#p+#n) (也有2^13种可能),#p和#n分别是正和负图像片的数目。那么整一个集合分类器(共10个基本分类器)就有10个后验概率了,将10个后验概率进行平均,如果大于阈值(一开始设经验值0.65,后面再训练优化)的话,就认为该图像片含有前景目标;

后验概率P(y|x)= #p/(#p+#n)的产生方法:初始化时,每个后验概率都得初始化为0;运行时候以下面方式更新:将已知类别标签的样本(训练样本)通过n个分类器进行分类,如果分类结果错误,那么相应的#p和#n就会更新,这样P(y|x)也相应更新了。

pixel comparisons的产生方法:先用一个归一化的patch去离散化像素空间,产生所有可能的垂直和水平的pixel comparisons,然后我们把这些pixel comparisons随机分配给n个分类器,每个分类器得到完全不同的pixel comparisons(特征集合),这样,所有分类器的特征组统一起来就可以覆盖整个patch了。

特征是相对于一种尺度的矩形框而言的,TLD中第s种尺度的第i个特征features[s][i] = Feature(x1, y1, x2, y2);是两个随机分配的像素点坐标(就是由这两个像素点比较得到0或者1的)。每一种尺度的扫描窗口都含有totalFeatures = nstructs * structSize个特征;nstructs为树木(由一个特征组构建,每组特征代表图像块的不同视图表示)的个数;structSize为每棵树的特征个数,也即每棵树的判断节点个数;树上每一个特征都作为一个决策节点;

prepare函数的工作就是先给每一个扫描窗口初始化了对应的pixel comparisons(两个随机分配的像素点坐标);然后初始化后验概率为0;

 

4.6、generatePositiveData(frame1, num_warps_init);

此函数通过对第一帧图像的目标框box(用户指定的要跟踪的目标)进行仿射变换来合成训练初始分类器的正样本集。具体方法如下:先在距离初始的目标框最近的扫描窗口内选择10个bounding box(已经由上面的getOverlappingBoxes函数得到,存于good_boxes里面了,还记得不?),然后在每个bounding box的内部,进行±1%范围的偏移,±1%范围的尺度变化,±10%范围的平面内旋转,并且在每个像素上增加方差为5的高斯噪声(确切的大小是在指定的范围内随机选择的),那么每个box都进行20次这种几何变换,那么10个box将产生200个仿射变换的bounding box,作为正样本。具体实现如下:

getPattern(frame(best_box), pEx, mean, stdev);此函数将frame图像best_box区域的图像片归一化为均值为0的15*15大小的patch,存于pEx(用于最近邻分类器的正样本)正样本中(最近邻的box的Pattern),该正样本只有一个。

generator(frame, pt, warped, bbhull.size(), rng);此函数属于PatchGenerator类的构造函数,用来对图像区域进行仿射变换,先RNG一个随机因子,再调用()运算符产生一个变换后的正样本。

classifier.getFeatures(patch, grid[idx].sidx, fern);函数得到输入的patch的特征fern(13位的二进制代码);

pX.push_back(make_pair(fern, 1));   //positive ferns <features, labels=1>然后标记为正样本,存入pX(用于集合分类器的正样本)正样本库;

以上的操作会循环 num_warps * good_boxes.size()即20 * 10 次,这样,pEx就有了一个正样本,而pX有了200个正样本了;

 

4.7、meanStdDev(frame1(best_box), mean, stdev);

统计best_box的均值和标准差,var = pow(stdev.val[0],2) * 0.5;作为方差分类器的阈值。

 

4.8、generateNegativeData(frame1);

     由于TLD仅跟踪一个目标,所以我们确定了目标框了,故除目标框外的其他图像都是负样本,无需仿射变换;具体实现如下:

     由于之前重叠度小于0.2的,都归入 bad_boxes了,所以数量挺多,把方差大于var*0.5f的bad_boxes都加入负样本,同上面一样,需要classifier.getFeatures(patch, grid[idx].sidx, fern);和nX.push_back(make_pair(fern, 0));得到对应的fern特征和标签的nX负样本(用于集合分类器的负样本);

    然后随机在上面的bad_boxes中取bad_patches(100个)个box,然后用 getPattern函数将frame图像bad_box区域的图像片归一化到15*15大小的patch,存在nEx(用于最近邻分类器的负样本)负样本中。

这样nEx和nX都有负样本了;(box的方差通过积分图像计算)

 

4.9、然后将nEx的一半作为训练集nEx,另一半作为测试集nExT;同样,nX也拆分为训练集nX和测试集nXT;

 

4.10、将负样本nX和正样本pX合并到ferns_data[]中,用于集合分类器的训练;

 

4.11、将上面得到的一个正样本pEx和nEx合并到nn_data[]中,用于最近邻分类器的训练;

 

4.12、用上面的样本训练集训练 集合分类器(森林) 和 最近邻分类器:

  classifier.trainF(ferns_data, 2); //bootstrap = 2

对每一个样本ferns_data[i] ,如果样本是正样本标签,先用measure_forest函数返回该样本所有树的所有特征值对应的后验概率累加值,该累加值如果小于正样本阈值(0.6* nstructs,这就表示平均值需要大于0.6(0.6* nstructs / nstructs),0.6是程序初始化时定的集合分类器的阈值,为经验值,后面会用测试集来评估修改,找到最优),也就是输入的是正样本,却被分类成负样本了,出现了分类错误,所以就把该样本添加到正样本库,同时用update函数更新后验概率。对于负样本,同样,如果出现负样本分类错误,就添加到负样本库。

  classifier.trainNN(nn_data);

     对每一个样本nn_data,如果标签是正样本,通过NNConf(nn_examples[i], isin, conf, dummy);计算输入图像片与在线模型之间的相关相似度conf,如果相关相似度小于0.65 ,则认为其不含有前景目标,也就是分类错误了;这时候就把它加到正样本库。然后就通过pEx.push_back(nn_examples[i]);将该样本添加到pEx正样本库中;同样,如果出现负样本分类错误,就添加到负样本库。

 

4.13、用测试集在上面得到的 集合分类器(森林) 和 最近邻分类器中分类,评价并修改得到最好的分类器阈值。

  classifier.evaluateTh(nXT, nExT);

   对集合分类器,对每一个测试集nXT,所有基本分类器的后验概率的平均值如果大于thr_fern(0.6),则认为含有前景目标,然后取最大的平均值(大于thr_fern)作为该集合分类器的新的阈值。

   对最近邻分类器,对每一个测试集nExT,最大相关相似度如果大于nn_fern(0.65),则认为含有前景目标,然后取最大的最大相关相似度(大于nn_fern)作为该最近邻分类器的新的阈值。

 

5、进入一个循环:读入新的一帧,然后转换为灰度图像,然后再处理每一帧processFrame;

 

6、tld.processFrame(last_gray, current_gray, pts1, pts2, pbox, status, tl, bb_file);逐帧读入图片序列,进行算法处理。processFrame共包含四个模块(依次处理):跟踪模块、检测模块、综合模块和学习模块;

 

6.1、跟踪模块:track(img1, img2, points1, points2);

track函数完成前一帧img1的特征点points1到当前帧img2的特征点points2的跟踪预测;

 

6.1.1、具体实现过程如下:

(1)先在lastbox中均匀采样10*10=100个特征点(网格均匀撒点),存于points1:

bbPoints(points1, lastbox);

(2)利用金字塔LK光流法跟踪这些特征点,并预测当前帧的特征点(见下面的解释)、计算FB error和匹配相似度sim,然后筛选出 FB_error[i] <= median(FB_error) 和 sim_error[i] > median(sim_error) 的特征点(舍弃跟踪结果不好的特征点),剩下的是不到50%的特征点

tracker.trackf2f(img1, img2, points, points2);

(3)利用剩下的这不到一半的跟踪点输入来预测bounding box在当前帧的位置和大小 tbb:

bbPredict(points, points2, lastbox, tbb);

(4)跟踪失败检测:如果FB error的中值大于10个像素(经验值),或者预测到的当前box的位置移出图像,则认为跟踪错误,此时不返回bounding box:

if (tracker.getFB()>10 || tbb.x>img2.cols ||  tbb.y>img2.rows || tbb.br().x < 1 || tbb.br().y <1)

(5)归一化img2(bb)对应的patch的size(放缩至patch_size = 15*15),存入pattern:

getPattern(img2(bb),pattern,mean,stdev);

(6)计算图像片pattern到在线模型M的保守相似度:

classifier.NNConf(pattern,isin,dummy,tconf);

(7)如果保守相似度大于阈值,则评估本次跟踪有效,否则跟踪无效:

if (tconf>classifier.thr_nn_valid) tvalid =true;

 

6.1.2、TLD跟踪模块的实现原理和trackf2f函数的实现:

   TLD跟踪模块的实现是利用了Media Flow 中值光流跟踪和跟踪错误检测算法的结合。中值流跟踪方法是基于Forward-Backward Error和NNC的。原理很简单:从t时刻的图像的A点,跟踪到t+1时刻的图像B点;然后倒回来,从t+1时刻的图像的B点往回跟踪,假如跟踪到t时刻的图像的C点,这样就产生了前向和后向两个轨迹,比较t时刻中 A点和C点的距离,如果距离小于一个阈值,那么就认为前向跟踪是正确的;这个距离就是FB_error;

bool LKTracker::trackf2f(const Mat& img1, const Mat& img2, vector<Point2f> &points1, vector<cv::Point2f> &points2)

函数实现过程如下:

(1)先利用金字塔LK光流法跟踪预测前向轨迹:

  calcOpticalFlowPyrLK( img1,img2, points1, points2, status, similarity, window_size, level, term_criteria, lambda, 0);

(2)再往回跟踪,产生后向轨迹:

  calcOpticalFlowPyrLK( img2,img1, points2, pointsFB, FB_status,FB_error, window_size, level, term_criteria, lambda, 0);

(3)然后计算 FB-error:前向与 后向 轨迹的误差:

  for( int i= 0; i<points1.size(); ++i )

        FB_error[i] = norm(pointsFB[i]-points1[i]);     

(4)再从前一帧和当前帧图像中(以每个特征点为中心)使用亚象素精度提取10x10象素矩形(使用函数getRectSubPix得到),匹配前一帧和当前帧中提取的10x10象素矩形,得到匹配后的映射图像(调用matchTemplate),得到每一个点的NCC相关系数(也就是相似度大小)。

normCrossCorrelation(img1, img2, points1, points2);

(5)然后筛选出 FB_error[i] <= median(FB_error) 和 sim_error[i] > median(sim_error) 的特征点(舍弃跟踪结果不好的特征点),剩下的是不到50%的特征点;

filterPts(points1, points2);

 

6.2、检测模块:detect(img2);

TLD的检测分类器有三部分:方差分类器模块、集合分类器模块和最近邻分类器模块;这三个分类器是级联的。当前帧img2的每一个扫描窗口依次通过上面三个分类器,全部通过才被认为含有前景目标。具体实现过程如下:

先计算img2的积分图,为了更快的计算方差:

integral(frame,iisum,iisqsum);

然后用高斯模糊,去噪:

  GaussianBlur(frame,img,Size(9,9),1.5); 

下一步就进入了方差检测模块:

 

6.2.1、方差分类器模块:getVar(grid[i],iisum,iisqsum) >= var

利用积分图计算每个待检测窗口的方差,方差大于var阈值(目标patch方差的50%)的,则认为其含有前景目标,通过该模块的进入集合分类器模块:

 

6.2.2、集合分类器模块:

集合分类器(随机森林)共有10颗树(基本分类器),每棵树13个判断节点,每个判断节点经比较得到一个二进制位0或者1,这样每棵树就对应得到一个13位的二进制码x(叶子),这个二进制码x对应于一个后验概率P(y|x)。那么整一个集合分类器(共10个基本分类器)就有10个后验概率了,将10个后验概率进行平均,如果大于阈值(一开始设经验值0.65,后面再训练优化)的话,就认为该图像片含有前景目标;具体过程如下:

(1)先得到该patch的特征值(13位的二进制代码):

classifier.getFeatures(patch,grid[i].sidx,ferns);

(2)再计算该特征值对应的后验概率累加值:

conf = classifier.measure_forest(ferns);           

(3)若集合分类器的后验概率的平均值大于阈值fern_th(由训练得到),就认为含有前景目标:

if (conf > numtrees * fern_th)  dt.bb.push_back(i); 

(4)将通过以上两个检测模块的扫描窗口记录在detect structure中;

(5)如果顺利通过以上两个检测模块的扫描窗口数大于100个,则只取后验概率大的前100个;

nth_element(dt.bb.begin(), dt.bb.begin()+100, dt.bb.end(),

CComparator(tmp.conf));

进入最近邻分类器:

 

6.2.3、最近邻分类器模块

(1)先归一化patch的size(放缩至patch_size = 15*15),存入dt.patch[i];

getPattern(patch,dt.patch[i],mean,stdev); 

(2)计算图像片pattern到在线模型M的相关相似度和保守相似度:

classifier.NNConf(dt.patch[i],dt.isin[i],dt.conf1[i],dt.conf2[i]);

(3)相关相似度大于阈值,则认为含有前景目标:

if (dt.conf1[i]>nn_th)  dbb.push_back(grid[idx]);

到目前为止,检测器检测完成,全部通过三个检测模块的扫描窗口存在dbb中;

 

6.3、综合模块:

TLD只跟踪单目标,所以综合模块综合跟踪器跟踪到的单个目标和检测器可能检测到的多个目标,然后只输出保守相似度最大的一个目标。具体实现过程如下:

(1)先通过 重叠度 对检测器检测到的目标bounding box进行聚类,每个类的重叠度小于0.5:

clusterConf(dbb, dconf, cbb, cconf);

(2)再找到与跟踪器跟踪到的box距离比较远的类(检测器检测到的box),而且它的相关相似度比跟踪器的要大:记录满足上述条件,也就是可信度比较高的目标box的个数:

if (bbOverlap(tbb, cbb[i])<0.5 && cconf[i]>tconf) confident_detections++;

(3)判断如果只有一个满足上述条件的box,那么就用这个目标box来重新初始化跟踪器(也就是用检测器的结果去纠正跟踪器):

if (confident_detections==1)  bbnext=cbb[didx];

(4)如果满足上述条件的box不只一个,那么就找到检测器检测到的box与跟踪器预测到的box距离很近(重叠度大于0.7)的所以box,对其坐标和大小进行累加:

if(bbOverlap(tbb,dbb[i])>0.7)  cx += dbb[i].x;……

(5)对与跟踪器预测到的box距离很近的box 和 跟踪器本身预测到的box 进行坐标与大小的平均作为最终的目标bounding box,但是跟踪器的权值较大:

bbnext.x = cvRound((float)(10*tbb.x+cx)/(float)(10+close_detections));……

(6)另外,如果跟踪器没有跟踪到目标,但是检测器检测到了一些可能的目标box,那么同样对其进行聚类,但只是简单的将聚类的cbb[0]作为新的跟踪目标box(不比较相似度了??还是里面已经排好序了??),重新初始化跟踪器:

bbnext=cbb[0];

至此,综合模块结束。

 

6.4、学习模块:learn(img2);

    学习模块也分为如下四部分:

6.4.1、检查一致性:

(1)归一化img(bb)对应的patch的size(放缩至patch_size = 15*15),存入pattern:

  getPattern(img(bb), pattern, mean, stdev);

(2)计算输入图像片(跟踪器的目标box)与在线模型之间的相关相似度conf:

  classifier.NNConf(pattern,isin,conf,dummy);

(3)如果相似度太小了或者如果方差太小了或者如果被被识别为负样本,那么就不训练了;

if (conf<0.5)……或if (pow(stdev.val[0], 2)< var)……或if(isin[2]==1)……

 

6.4.2、生成样本:

先是集合分类器的样本:fern_examples:

(1)先计算所有的扫描窗口与目前的目标box的重叠度:

grid[i].overlap = bbOverlap(lastbox, grid[i]);

(2)再根据传入的lastbox,在整帧图像中的全部窗口中寻找与该lastbox距离最小(即最相似,重叠度最大)的num_closest_update个窗口,然后把这些窗口归入good_boxes容器(只是把网格数组的索引存入)同时,把重叠度小于0.2的,归入 bad_boxes 容器:

  getOverlappingBoxes(lastbox, num_closest_update);

(3)然后用仿射模型产生正样本(类似于第一帧的方法,但只产生10*10=100个):

generatePositiveData(img, num_warps_update); 

(4)加入负样本,相似度大于1??相似度不是出于0和1之间吗?

idx=bad_boxes[i];

if (tmp.conf[idx]>=1) fern_examples.push_back(make_pair(tmp.patt[idx],0));

然后是最近邻分类器的样本:nn_examples:

if (bbOverlap(lastbox,grid[idx]) < bad_overlap)

        nn_examples.push_back(dt.patch[i]);

 

6.4.3、分类器训练:

classifier.trainF(fern_examples,2);

classifier.trainNN(nn_examples);

 

6.4.4、把正样本库(在线模型)包含的所有正样本显示在窗口上

classifier.show();

至此,tld.processFrame函数结束。

 

7、如果跟踪成功,则把相应的点和box画出来:

    if (status){

      drawPoints(frame,pts1);

      drawPoints(frame,pts2,Scalar(0,255,0));  //当前的特征点用蓝色点表示

      drawBox(frame,pbox);

      detections++;

}

 

8、然后显示窗口和交换图像帧,进入下一帧的处理:

    imshow("TLD", frame);

swap(last_gray, current_gray);

至此,main()函数结束(只分析了框架)。

下面是自己在看论文和这些大牛的分析过程中,对代码进行了一些理解,但是由于自己接触图像处理和机器视觉没多久,另外由于自己编程能力比较弱,所以分析过程可能会有不少的错误,希望各位不吝指正。而且,因为编程很多地方不懂,所以注释得非常乱,还海涵。

run_tld.cpp

[cpp]  view plain copy
  1. #include <opencv2/opencv.hpp>  
  2. #include <tld_utils.h>  
  3. #include <iostream>  
  4. #include <sstream>  //c++中的sstream类,提供了程序和string对象之间的I/O,可以通过ostringstream  
  5.                     //和instringstream两个类来声明对象,分别对应输出流和输入流  
  6. #include <TLD.h>  
  7. #include <stdio.h>  
  8. using namespace cv;  
  9. using namespace std;  
  10. //Global variables  
  11. Rect box;  
  12. bool drawing_box = false;  
  13. bool gotBB = false;  
  14. bool tl = true;  
  15. bool rep = false;  
  16. bool fromfile=false;  
  17. string video;  
  18.   
  19. //读取记录bounding box的文件,获得bounding box的四个参数:左上角坐标x,y和宽高  
  20. /*如在\datasets\06_car\init.txt中:记录了初始目标的bounding box,内容如下 
  21. 142,125,232,164    
  22. */  
  23. void readBB(char* file){  
  24.   ifstream bb_file (file);  //以输入方式打开文件  
  25.   string line;  
  26.   //istream& getline ( istream& , string& );  
  27.   //将输入流is中读到的字符存入str中,终结符默认为 '\n'(换行符)   
  28.   getline(bb_file, line);  
  29.   istringstream linestream(line); //istringstream对象可以绑定一行字符串,然后以空格为分隔符把该行分隔开来。  
  30.   string x1,y1,x2,y2;  
  31.     
  32.   //istream& getline ( istream &is , string &str , char delim );   
  33.   //将输入流is中读到的字符存入str中,直到遇到终结符delim才结束。  
  34.   getline (linestream,x1, ',');  
  35.   getline (linestream,y1, ',');  
  36.   getline (linestream,x2, ',');  
  37.   getline (linestream,y2, ',');  
  38.     
  39.   //atoi 功 能: 把字符串转换成整型数  
  40.   int x = atoi(x1.c_str());// = (int)file["bb_x"];  
  41.   int y = atoi(y1.c_str());// = (int)file["bb_y"];  
  42.   int w = atoi(x2.c_str())-x;// = (int)file["bb_w"];  
  43.   int h = atoi(y2.c_str())-y;// = (int)file["bb_h"];  
  44.   box = Rect(x,y,w,h);  
  45. }  
  46.   
  47. //bounding box mouse callback  
  48. //鼠标的响应就是得到目标区域的范围,用鼠标选中bounding box。  
  49. void mouseHandler(int event, int x, int y, int flags, void *param){  
  50.   switch( event ){  
  51.   case CV_EVENT_MOUSEMOVE:  
  52.     if (drawing_box){  
  53.         box.width = x-box.x;  
  54.         box.height = y-box.y;  
  55.     }  
  56.     break;  
  57.   case CV_EVENT_LBUTTONDOWN:  
  58.     drawing_box = true;  
  59.     box = Rect( x, y, 0, 0 );  
  60.     break;  
  61.   case CV_EVENT_LBUTTONUP:  
  62.     drawing_box = false;  
  63.     if( box.width < 0 ){  
  64.         box.x += box.width;  
  65.         box.width *= -1;  
  66.     }  
  67.     if( box.height < 0 ){  
  68.         box.y += box.height;  
  69.         box.height *= -1;  
  70.     }  
  71.     gotBB = true;   //已经获得bounding box  
  72.     break;  
  73.   }  
  74. }  
  75.   
  76. void print_help(char** argv){  
  77.   printf("use:\n     %s -p /path/parameters.yml\n",argv[0]);  
  78.   printf("-s    source video\n-b        bounding box file\n-tl  track and learn\n-r     repeat\n");  
  79. }  
  80.   
  81. //分析运行程序时的命令行参数  
  82. void read_options(int argc, char** argv, VideoCapture& capture, FileStorage &fs){  
  83.   for (int i=0;i<argc;i++){  
  84.       if (strcmp(argv[i],"-b")==0){  
  85.           if (argc>i){  
  86.               readBB(argv[i+1]);  //是否指定初始的bounding box  
  87.               gotBB = true;  
  88.           }  
  89.           else  
  90.             print_help(argv);  
  91.       }  
  92.       if (strcmp(argv[i],"-s")==0){   //从视频文件中读取  
  93.           if (argc>i){  
  94.               video = string(argv[i+1]);  
  95.               capture.open(video);  
  96.               fromfile = true;  
  97.           }  
  98.           else  
  99.             print_help(argv);  
  100.   
  101.       }  
  102.       //Similar in format to XML, Yahoo! Markup Language (YML) provides functionality to Open   
  103.       //Applications in a safe and standardized fashion. You include YML tags in the HTML code  
  104.       //of an Open Application.  
  105.       if (strcmp(argv[i],"-p")==0){   //读取参数文件parameters.yml  
  106.           if (argc>i){  
  107.           //FileStorage类的读取方式可以是:FileStorage fs(".\\parameters.yml", FileStorage::READ);    
  108.               fs.open(argv[i+1], FileStorage::READ);  
  109.           }  
  110.           else  
  111.             print_help(argv);  
  112.       }  
  113.       if (strcmp(argv[i],"-no_tl")==0){  //To train only in the first frame (no tracking, no learning)  
  114.           tl = false;  
  115.       }  
  116.       if (strcmp(argv[i],"-r")==0){  //Repeat the video, first time learns, second time detects  
  117.           rep = true;  
  118.       }  
  119.   }  
  120. }  
  121.   
  122. /* 
  123. 运行程序时: 
  124. %To run from camera 
  125. ./run_tld -p ../parameters.yml 
  126. %To run from file 
  127. ./run_tld -p ../parameters.yml -s ../datasets/06_car/car.mpg 
  128. %To init bounding box from file 
  129. ./run_tld -p ../parameters.yml -s ../datasets/06_car/car.mpg -b ../datasets/06_car/init.txt 
  130. %To train only in the first frame (no tracking, no learning) 
  131. ./run_tld -p ../parameters.yml -s ../datasets/06_car/car.mpg -b ../datasets/06_car/init.txt -no_tl  
  132. %To test the final detector (Repeat the video, first time learns, second time detects) 
  133. ./run_tld -p ../parameters.yml -s ../datasets/06_car/car.mpg -b ../datasets/06_car/init.txt -r 
  134. */  
  135. //感觉就是对起始帧进行初始化工作,然后逐帧读入图片序列,进行算法处理。  
  136. int main(int argc, char * argv[]){  
  137.   VideoCapture capture;  
  138.   capture.open(0);  
  139.     
  140.   //OpenCV的C++接口中,用于保存图像的imwrite只能保存整数数据,且需作为图像格式。当需要保存浮  
  141.   //点数据或XML/YML文件时,OpenCV的C语言接口提供了cvSave函数,但这一函数在C++接口中已经被删除。  
  142.   //取而代之的是FileStorage类。  
  143.   FileStorage fs;  
  144.   //Read options  
  145.   read_options(argc, argv, capture, fs);  //分析命令行参数  
  146.   //Init camera  
  147.   if (!capture.isOpened())  
  148.   {  
  149.     cout << "capture device failed to open!" << endl;  
  150.     return 1;  
  151.   }  
  152.   //Register mouse callback to draw the bounding box  
  153.   cvNamedWindow("TLD",CV_WINDOW_AUTOSIZE);  
  154.   cvSetMouseCallback( "TLD", mouseHandler, NULL );  //用鼠标选中初始目标的bounding box  
  155.   //TLD framework  
  156.   TLD tld;  
  157.   //Read parameters file  
  158.   tld.read(fs.getFirstTopLevelNode());  
  159.   Mat frame;  
  160.   Mat last_gray;  
  161.   Mat first;  
  162.   if (fromfile){  //如果指定为从文件读取  
  163.       capture >> frame;   //读当前帧  
  164.       cvtColor(frame, last_gray, CV_RGB2GRAY);  //转换为灰度图像  
  165.       frame.copyTo(first);  //拷贝作为第一帧  
  166.   }else{   //如果为读取摄像头,则设置获取的图像大小为320x240   
  167.       capture.set(CV_CAP_PROP_FRAME_WIDTH,340);  //340??  
  168.       capture.set(CV_CAP_PROP_FRAME_HEIGHT,240);  
  169.   }  
  170.   
  171.   ///Initialization  
  172. GETBOUNDINGBOX:   //标号:获取bounding box  
  173.   while(!gotBB)  
  174.   {  
  175.     if (!fromfile){  
  176.       capture >> frame;  
  177.     }  
  178.     else  
  179.       first.copyTo(frame);  
  180.     cvtColor(frame, last_gray, CV_RGB2GRAY);  
  181.     drawBox(frame,box);  //把bounding box 画出来  
  182.     imshow("TLD", frame);  
  183.     if (cvWaitKey(33) == 'q')  
  184.         return 0;  
  185.   }  
  186.   //由于图像片(min_win 为15x15像素)是在bounding box中采样得到的,所以box必须比min_win要大  
  187.   if (min(box.width, box.height)<(int)fs.getFirstTopLevelNode()["min_win"]){  
  188.       cout << "Bounding box too small, try again." << endl;  
  189.       gotBB = false;  
  190.       goto GETBOUNDINGBOX;  
  191.   }  
  192.   //Remove callback  
  193.   cvSetMouseCallback( "TLD", NULL, NULL );  //如果已经获得第一帧用户框定的box了,就取消鼠标响应  
  194.   printf("Initial Bounding Box = x:%d y:%d h:%d w:%d\n",box.x,box.y,box.width,box.height);  
  195.   //Output file  
  196.   FILE  *bb_file = fopen("bounding_boxes.txt","w");  
  197.     
  198.   //TLD initialization  
  199.   tld.init(last_gray, box, bb_file);  
  200.   
  201.   ///Run-time  
  202.   Mat current_gray;  
  203.   BoundingBox pbox;  
  204.   vector<Point2f> pts1;  
  205.   vector<Point2f> pts2;  
  206.   bool status=true;  //记录跟踪成功与否的状态 lastbox been found  
  207.   int frames = 1;  //记录已过去帧数  
  208.   int detections = 1;  //记录成功检测到的目标box数目  
  209.     
  210. REPEAT:  
  211.   while(capture.read(frame)){  
  212.     //get frame  
  213.     cvtColor(frame, current_gray, CV_RGB2GRAY);  
  214.     //Process Frame  
  215.     tld.processFrame(last_gray, current_gray, pts1, pts2, pbox, status, tl, bb_file);  
  216.     //Draw Points  
  217.     if (status){  //如果跟踪成功  
  218.       drawPoints(frame,pts1);  
  219.       drawPoints(frame,pts2,Scalar(0,255,0));  //当前的特征点用蓝色点表示  
  220.       drawBox(frame,pbox);  
  221.       detections++;  
  222.     }  
  223.     //Display  
  224.     imshow("TLD", frame);  
  225.     //swap points and images  
  226.     swap(last_gray, current_gray);  //STL函数swap()用来交换两对象的值。其泛型化版本定义于<algorithm>;  
  227.     pts1.clear();  
  228.     pts2.clear();  
  229.     frames++;  
  230.     printf("Detection rate: %d/%d\n", detections, frames);  
  231.     if (cvWaitKey(33) == 'q')  
  232.       break;  
  233.   }  
  234.   if (rep){  
  235.     rep = false;  
  236.     tl = false;  
  237.     fclose(bb_file);  
  238.     bb_file = fopen("final_detector.txt","w");  
  239.     //capture.set(CV_CAP_PROP_POS_AVI_RATIO,0);  
  240.     capture.release();  
  241.     capture.open(video);  
  242.     goto REPEAT;  
  243.   }  
  244.   fclose(bb_file);  
  245.   return 0;  
  246. }  

 

tld_utils.cpp

[cpp]  view plain copy
  1. #include <tld_utils.h>  
  2. using namespace cv;  
  3. using namespace std;  
  4.   
  5. /*vector是C++标准模板库STL中的部分内容,它是一个多功能的,能够操作多种数据结构和算法的 
  6. 模板类和函数库。vector之所以被认为是一个容器,是因为它能够像容器一样存放各种类型的对象, 
  7. 简单地说,vector是一个能够存放任意类型的动态数组,能够增加和压缩数据。 
  8. 为了可以使用vector,必须在你的头文件中包含下面的代码: 
  9. #include <vector> 
  10. vector属于std命名域的,因此需要通过命名限定,如下完成你的代码: 
  11. using std::vector; 
  12. */  
  13.   
  14. void drawBox(Mat& image, CvRect box, Scalar color, int thick){  
  15.   rectangle( image, cvPoint(box.x, box.y), cvPoint(box.x+box.width,box.y+box.height),color, thick);  
  16. }   
  17.   
  18. //函数 cvRound, cvFloor, cvCeil 用一种舍入方法将输入浮点数转换成整数。  
  19. //cvRound 返回和参数最接近的整数值。 cvFloor 返回不大于参数的最大整数值。  
  20. //cvCeil 返回不小于参数的最小整数值。  
  21. void drawPoints(Mat& image, vector<Point2f> points,Scalar color){  
  22.   for( vector<Point2f>::const_iterator i = points.begin(), ie = points.end(); i != ie; ++i )  
  23.       {  
  24.       Point center( cvRound(i->x ), cvRound(i->y));  //类似于int i(3)的初始化,但center为何没用到?  
  25.       circle(image,*i,2,color,1);  
  26.       }  
  27. }  
  28.   
  29. Mat createMask(const Mat& image, CvRect box){  
  30.   Mat mask = Mat::zeros(image.rows,image.cols,CV_8U);  
  31.   drawBox(mask,box,Scalar::all(255),CV_FILLED);  
  32.   return mask;  
  33. }  
  34.   
  35. //STL中的nth_element()方法找出一个数列中排名第n的那个数。  
  36. //对于序列a[0:len-1]将第n大的数字,排在a[n],同时a[0:n-1]都小于a[n],a[n+1:]都大于a[n],  
  37. //但a[n]左右的这两个序列不一定有序。  
  38. //用在中值流跟踪算法中,寻找中值  
  39. float median(vector<float> v)  
  40. {  
  41.     int n = floor(v.size() / 2);  
  42.     nth_element(v.begin(), v.begin()+n, v.end());  
  43.     return v[n];  
  44. }  
  45.   
  46. //<algorithm> //random_shuffle的头文件  
  47. //shuffle 洗牌  首先简单的介绍一个扑克牌洗牌的方法,假设一个数组 poker[52] 中存有一副扑克  
  48. //牌1-52的牌点值,使用一个for循环遍历这个数组,每次循环都生成一个[0,52)之间的随机数RandNum,  
  49. //以RandNum为数组下标,把当前下标对应的值和RandNum对应位置的值交换,循环结束,每个牌都与某个  
  50. //位置交换了一次,这样一副牌就被打乱了。 理解代码如下:  
  51. /* 
  52. for (int i = 0; i < 52; ++i)   
  53. {   
  54.     int RandNum = rand() % 52;     
  55.     int tmp = poker[i];   
  56.     poker[i] = poker[RandNum];   
  57.     poker[RandNum] = tmp;   
  58.  
  59. */  
  60. //需要指定范围内的随机数,传统的方法是使用ANSI C的函数random(),然后格式化结果以便结果是落在  
  61. //指定的范围内。但是,使用这个方法至少有两个缺点。做格式化时,结果常常是扭曲的,且只支持整型数。  
  62. //C++中提供了更好的解决方法,那就是STL中的random_shuffle()算法。产生指定范围内的随机元素集的最佳方法  
  63. //是创建一个顺序序列(也就是向量或者内置数组),在这个顺序序列中含有指定范围的所有值。  
  64. //例如,如果你需要产生100个0-99之间的数,那么就创建一个向量并用100个按升序排列的数填充向量.  
  65. //填充完向量之后,用random_shuffle()算法打乱元素排列顺序。  
  66. //默认的random_shuffle中, 被操作序列的index 与 rand() % N 两个位置的值交换,来达到乱序的目的。  
  67. //index_shuffle()用于产生指定范围[begin:end]的随机数,返回随机数数组  
  68. vector<int> index_shuffle(int begin,int end){  
  69.   vector<int> indexes(end-begin);  
  70.   for (int i=begin;i<end;i++){  
  71.     indexes[i]=i;  
  72.   }  
  73.   random_shuffle(indexes.begin(),indexes.end());  
  74.   return indexes;  
  75. }  

 下面是自己在看论文和这些大牛的分析过程中,对代码进行了一些理解,但是由于自己接触图像处理和机器视觉没多久,另外由于自己编程能力比较弱,所以分析过程可能会有不少的错误,希望各位不吝指正。而且,因为编程很多地方不懂,所以注释得非常乱,还海涵。

LKTracker.h

[cpp]  view plain copy
  1. #include<tld_utils.h>  
  2. #include <opencv2/opencv.hpp>  
  3.   
  4. //使用金字塔LK光流法跟踪,所以类的成员变量很多都是OpenCV中calcOpticalFlowPyrLK()函数的参数  
  5. class LKTracker{  
  6. private:  
  7.   std::vector<cv::Point2f> pointsFB;  
  8.   cv::Size window_size;  //每个金字塔层的搜索窗口尺寸  
  9.   int level;            //最大的金字塔层数  
  10.   std::vector<uchar> status;   //数组。如果对应特征的光流被发现,数组中的每一个元素都被设置为 1, 否则设置为 0  
  11.   std::vector<uchar> FB_status;     
  12.   std::vector<float> similarity;  //相似度  
  13.   std::vector<float> FB_error;   //Forward-Backward error方法,求FB_error的结果与原始位置的欧式距离  
  14.                                  //做比较,把距离过大的跟踪结果舍弃  
  15.   float simmed;  
  16.   float fbmed;  
  17.   //TermCriteria模板类,取代了之前的CvTermCriteria,这个类是作为迭代算法的终止条件的  
  18.   //该类变量需要3个参数,一个是类型,第二个参数为迭代的最大次数,最后一个是特定的阈值。  
  19.   //指定在每个金字塔层,为某点寻找光流的迭代过程的终止条件。  
  20.   cv::TermCriteria term_criteria;  
  21.   float lambda;   //某阈值??Lagrangian 乘子  
  22.   // NCC 归一化交叉相关,FB error与NCC结合,使跟踪更稳定  交叉相关的图像匹配算法??  
  23.   //交叉相关法的作用是进行云团移动的短时预测。选取连续两个时次的GMS-5卫星云图,将云图区域划分为32×32像素  
  24.   //的图像子集,采用交叉相关法计算获取两幅云图的最佳匹配区域,根据前后云图匹配区域的位置和时间间隔,确  
  25.   //定出每个图像子集的移动矢量(速度和方向),并对图像子集的移动矢量进行客观分析,其后,基于检验后的云  
  26.   //图移动矢量集,利用后向轨迹方法对云图作短时外推预测。  
  27.   void normCrossCorrelation(const cv::Mat& img1, const cv::Mat& img2, std::vector<cv::Point2f>& points1, std::vector<cv::Point2f>& points2);  
  28.   bool filterPts(std::vector<cv::Point2f>& points1,std::vector<cv::Point2f>& points2);  
  29. public:  
  30.   LKTracker();  
  31.   //特征点的跟踪??  
  32.   bool trackf2f(const cv::Mat& img1, const cv::Mat& img2,  
  33.                 std::vector<cv::Point2f> &points1, std::vector<cv::Point2f> &points2);  
  34.   float getFB(){return fbmed;}  
  35. };  


 

LKTracker.cpp

[cpp]  view plain copy
  1. #include <LKTracker.h>  
  2. using namespace cv;  
  3.   
  4. //金字塔LK光流法跟踪  
  5. //Media Flow 中值光流跟踪 加 跟踪错误检测  
  6. //构造函数,初始化成员变量  
  7. LKTracker::LKTracker(){  
  8.   该类变量需要3个参数,一个是类型,第二个参数为迭代的最大次数,最后一个是特定的阈值。  
  9.   term_criteria = TermCriteria( TermCriteria::COUNT + TermCriteria::EPS, 20, 0.03);  
  10.   window_size = Size(4,4);  
  11.   level = 5;  
  12.   lambda = 0.5;  
  13. }  
  14.   
  15.   
  16. bool LKTracker::trackf2f(const Mat& img1, const Mat& img2, vector<Point2f> &points1, vector<cv::Point2f> &points2){  
  17.   //TODO!:implement c function cvCalcOpticalFlowPyrLK() or Faster tracking function  
  18.   //Forward-Backward tracking  
  19.   //基于Forward-Backward Error的中值流跟踪方法  
  20.   //金字塔LK光流法跟踪  
  21.   //forward trajectory 前向轨迹跟踪  
  22.   calcOpticalFlowPyrLK( img1,img2, points1, points2, status, similarity, window_size, level, term_criteria, lambda, 0);  
  23.   //backward trajectory 后向轨迹跟踪  
  24.   calcOpticalFlowPyrLK( img2,img1, points2, pointsFB, FB_status,FB_error, window_size, level, term_criteria, lambda, 0);  
  25.     
  26.   //Compute the real FB-error  
  27.   //原理很简单:从t时刻的图像的A点,跟踪到t+1时刻的图像B点;然后倒回来,从t+1时刻的图像的B点往回跟踪,  
  28.   //假如跟踪到t时刻的图像的C点,这样就产生了前向和后向两个轨迹,比较t时刻中 A点 和 C点 的距离,如果距离  
  29.   //小于一个阈值,那么就认为前向跟踪是正确的;这个距离就是FB_error  
  30.   //计算 前向 与 后向 轨迹的误差  
  31.   forint i= 0; i<points1.size(); ++i ){  
  32.         FB_error[i] = norm(pointsFB[i]-points1[i]);   //norm()求矩阵或向量的范数??绝对值?  
  33.   }  
  34.   //Filter out points with FB_error[i] <= median(FB_error) && points with sim_error[i] > median(sim_error)  
  35.   normCrossCorrelation(img1, img2, points1, points2);  
  36.   return filterPts(points1, points2);  
  37. }  
  38.   
  39. //利用NCC把跟踪预测的结果周围取10*10的小图片与原始位置周围10*10的小图片(使用函数getRectSubPix得到)进  
  40. //行模板匹配(调用matchTemplate)  
  41. void LKTracker::normCrossCorrelation(const Mat& img1,const Mat& img2, vector<Point2f>& points1, vector<Point2f>& points2) {  
  42.         Mat rec0(10,10,CV_8U);  
  43.         Mat rec1(10,10,CV_8U);  
  44.         Mat res(1,1,CV_32F);  
  45.   
  46.         for (int i = 0; i < points1.size(); i++) {  
  47.                 if (status[i] == 1) {  //为1表示该特征点跟踪成功  
  48.                         //从前一帧和当前帧图像中(以每个特征点为中心?)提取10x10象素矩形,使用亚象素精度  
  49.                         getRectSubPix( img1, Size(10,10), points1[i],rec0 );     
  50.                         getRectSubPix( img2, Size(10,10), points2[i],rec1);  
  51.                         //匹配前一帧和当前帧中提取的10x10象素矩形,得到匹配后的映射图像  
  52.                         //CV_TM_CCOEFF_NORMED 归一化相关系数匹配法  
  53.                         //参数分别为:欲搜索的图像。搜索模板。比较结果的映射图像。指定匹配方法  
  54.                         matchTemplate( rec0,rec1, res, CV_TM_CCOEFF_NORMED);   
  55.                         similarity[i] = ((float *)(res.data))[0];  //得到各个特征点的相似度大小  
  56.   
  57.                 } else {  
  58.                         similarity[i] = 0.0;  
  59.                 }  
  60.         }  
  61.         rec0.release();  
  62.         rec1.release();  
  63.         res.release();  
  64. }  
  65.   
  66. //筛选出 FB_error[i] <= median(FB_error) 和 sim_error[i] > median(sim_error) 的特征点  
  67. //得到NCC和FB error结果的中值,分别去掉中值一半的跟踪结果不好的点  
  68. bool LKTracker::filterPts(vector<Point2f>& points1,vector<Point2f>& points2){  
  69.   //Get Error Medians  
  70.   simmed = median(similarity);   //找到相似度的中值  
  71.   size_t i, k;  
  72.   for( i=k = 0; i<points2.size(); ++i ){  
  73.         if( !status[i])  
  74.           continue;  
  75.         if(similarity[i]> simmed){   //剩下 similarity[i]> simmed 的特征点  
  76.           points1[k] = points1[i];     
  77.           points2[k] = points2[i];  
  78.           FB_error[k] = FB_error[i];  
  79.           k++;  
  80.         }  
  81.     }  
  82.   if (k==0)  
  83.     return false;  
  84.   points1.resize(k);  
  85.   points2.resize(k);  
  86.   FB_error.resize(k);  
  87.   
  88.   fbmed = median(FB_error);     //找到FB_error的中值  
  89.   for( i=k = 0; i<points2.size(); ++i ){  
  90.       if( !status[i])  
  91.         continue;  
  92.       if(FB_error[i] <= fbmed){   /  
  93.         points1[k] = points1[i];   //再对上一步剩下的特征点进一步筛选,剩下 FB_error[i] <= fbmed 的特征点  
  94.         points2[k] = points2[i];  
  95.         k++;  
  96.       }  
  97.   }  
  98.   points1.resize(k);  
  99.   points2.resize(k);  
  100.   if (k>0)  
  101.     return true;  
  102.   else  
  103.     return false;  
  104. }  
  105.   
  106.   
  107.   
  108.   
  109. /* 
  110.  * old OpenCV style 
  111. void LKTracker::init(Mat img0, vector<Point2f> &points){ 
  112.   //Preallocate 
  113.   //pyr1 = cvCreateImage(Size(img1.width+8,img1.height/3),IPL_DEPTH_32F,1); 
  114.   //pyr2 = cvCreateImage(Size(img1.width+8,img1.height/3),IPL_DEPTH_32F,1); 
  115.   //const int NUM_PTS = points.size(); 
  116.   //status = new char[NUM_PTS]; 
  117.   //track_error = new float[NUM_PTS]; 
  118.   //FB_error = new float[NUM_PTS]; 
  119. } 
  120.  
  121.  
  122. void LKTracker::trackf2f(..){ 
  123.   cvCalcOpticalFlowPyrLK( &img1, &img2, pyr1, pyr1, points1, points2, points1.size(), window_size, level, status, track_error, term_criteria, CV_LKFLOW_INITIAL_GUESSES); 
  124.   cvCalcOpticalFlowPyrLK( &img2, &img1, pyr2, pyr1, points2, pointsFB, points2.size(),window_size, level, 0, 0, term_criteria, CV_LKFLOW_INITIAL_GUESSES | CV_LKFLOW_PYR_A_READY | CV_LKFLOW_PYR_B_READY ); 
  125. } 
  126. */  

下面是自己在看论文和这些大牛的分析过程中,对代码进行了一些理解,但是由于自己接触图像处理和机器视觉没多久,另外由于自己编程能力比较弱,所以分析过程可能会有不少的错误,希望各位不吝指正。而且,因为编程很多地方不懂,所以注释得非常乱,还海涵。

TLD.h

[cpp]  view plain copy
  1. #include <opencv2/opencv.hpp>  
  2. #include <tld_utils.h>  
  3. #include <LKTracker.h>  
  4. #include <FerNNClassifier.h>  
  5. #include <fstream>  
  6.   
  7.   
  8. //Bounding Boxes  
  9. struct BoundingBox : public cv::Rect {  
  10.   BoundingBox(){}  
  11.   BoundingBox(cv::Rect r): cv::Rect(r){}   //继承的话需要初始化基类  
  12. public:  
  13.   float overlap;        //Overlap with current Bounding Box  
  14.   int sidx;             //scale index  
  15. };  
  16.   
  17. //Detection structure  
  18. struct DetStruct {  
  19.     std::vector<int> bb;  
  20.     std::vector<std::vector<int> > patt;  
  21.     std::vector<float> conf1;  
  22.     std::vector<float> conf2;  
  23.     std::vector<std::vector<int> > isin;  
  24.     std::vector<cv::Mat> patch;  
  25.   };  
  26.     
  27. //Temporal structure  
  28. struct TempStruct {  
  29.     std::vector<std::vector<int> > patt;  
  30.     std::vector<float> conf;  
  31.   };  
  32.   
  33. struct OComparator{  //比较两者重合度  
  34.   OComparator(const std::vector<BoundingBox>& _grid):grid(_grid){}  
  35.   std::vector<BoundingBox> grid;  
  36.   bool operator()(int idx1,int idx2){  
  37.     return grid[idx1].overlap > grid[idx2].overlap;  
  38.   }  
  39. };  
  40.   
  41. struct CComparator{  //比较两者确信度?  
  42.   CComparator(const std::vector<float>& _conf):conf(_conf){}  
  43.   std::vector<float> conf;  
  44.   bool operator()(int idx1,int idx2){  
  45.     return conf[idx1]> conf[idx2];  
  46.   }  
  47. };  
  48.   
  49.   
  50. class TLD{  
  51. private:  
  52.   cv::PatchGenerator generator;  //PatchGenerator类用来对图像区域进行仿射变换  
  53.   FerNNClassifier classifier;  
  54.   LKTracker tracker;  
  55.     
  56.   //下面这些参数通过程序开始运行时读入parameters.yml文件进行初始化  
  57.   ///Parameters  
  58.   int bbox_step;  
  59.   int min_win;  
  60.   int patch_size;  
  61.     
  62.   //initial parameters for positive examples  
  63.   //从第一帧得到的目标的bounding box中(文件读取或者用户框定),经过几何变换得  
  64.   //到 num_closest_init * num_warps_init 个正样本  
  65.   int num_closest_init;  //最近邻窗口数 10  
  66.   int num_warps_init;  //几何变换数目 20  
  67.   int noise_init;  
  68.   float angle_init;  
  69.   float shift_init;  
  70.   float scale_init;  
  71.     
  72.   从跟踪得到的目标的bounding box中,经过几何变换更新正样本(添加到在线模型?)  
  73.   //update parameters for positive examples  
  74.   int num_closest_update;  
  75.   int num_warps_update;  
  76.   int noise_update;  
  77.   float angle_update;  
  78.   float shift_update;  
  79.   float scale_update;  
  80.     
  81.   //parameters for negative examples  
  82.   float bad_overlap;  
  83.   float bad_patches;  
  84.     
  85.   ///Variables  
  86. //Integral Images  积分图像,用以计算2bitBP特征(类似于haar特征的计算)  
  87. //Mat最大的优势跟STL很相似,都是对内存进行动态的管理,不需要之前用户手动的管理内存  
  88.   cv::Mat iisum;  
  89.   cv::Mat iisqsum;  
  90.   float var;  
  91.     
  92. //Training data  
  93.   //std::pair主要的作用是将两个数据组合成一个数据,两个数据可以是同一类型或者不同类型。  
  94.   //pair实质上是一个结构体,其主要的两个成员变量是first和second,这两个变量可以直接使用。  
  95.   //在这里用来表示样本,first成员为 features 特征点数组,second成员为 labels 样本类别标签  
  96.   std::vector<std::pair<std::vector<int>,int> > pX; //positive ferns <features,labels=1>  正样本  
  97.   std::vector<std::pair<std::vector<int>,int> > nX; // negative ferns <features,labels=0>  负样本  
  98.   cv::Mat pEx;  //positive NN example    
  99.   std::vector<cv::Mat> nEx; //negative NN examples  
  100.     
  101. //Test data   
  102.   std::vector<std::pair<std::vector<int>,int> > nXT; //negative data to Test  
  103.   std::vector<cv::Mat> nExT; //negative NN examples to Test  
  104.     
  105. //Last frame data  
  106.   BoundingBox lastbox;  
  107.   bool lastvalid;  
  108.   float lastconf;  
  109.     
  110. //Current frame data  
  111.   //Tracker data  
  112.   bool tracked;  
  113.   BoundingBox tbb;  
  114.   bool tvalid;  
  115.   float tconf;  
  116.     
  117.   //Detector data  
  118.   TempStruct tmp;  
  119.   DetStruct dt;  
  120.   std::vector<BoundingBox> dbb;  
  121.   std::vector<bool> dvalid;   //检测有效性??  
  122.   std::vector<float> dconf;  //检测确信度??  
  123.   bool detected;  
  124.   
  125.   
  126.   //Bounding Boxes  
  127.   std::vector<BoundingBox> grid;  
  128.   std::vector<cv::Size> scales;  
  129.   std::vector<int> good_boxes; //indexes of bboxes with overlap > 0.6  
  130.   std::vector<int> bad_boxes; //indexes of bboxes with overlap < 0.2  
  131.   BoundingBox bbhull; // hull of good_boxes  //good_boxes 的 壳,也就是窗口的边框  
  132.   BoundingBox best_box; // maximum overlapping bbox  
  133.   
  134. public:  
  135.   //Constructors  
  136.   TLD();  
  137.   TLD(const cv::FileNode& file);  
  138.   void read(const cv::FileNode& file);  
  139.     
  140.   //Methods  
  141.   void init(const cv::Mat& frame1,const cv::Rect &box, FILE* bb_file);  
  142.   void generatePositiveData(const cv::Mat& frame, int num_warps);  
  143.   void generateNegativeData(const cv::Mat& frame);  
  144.   void processFrame(const cv::Mat& img1,const cv::Mat& img2,std::vector<cv::Point2f>& points1,std::vector<cv::Point2f>& points2,  
  145.       BoundingBox& bbnext,bool& lastboxfound, bool tl,FILE* bb_file);  
  146.   void track(const cv::Mat& img1, const cv::Mat& img2,std::vector<cv::Point2f>& points1,std::vector<cv::Point2f>& points2);  
  147.   void detect(const cv::Mat& frame);  
  148.   void clusterConf(const std::vector<BoundingBox>& dbb,const std::vector<float>& dconf,std::vector<BoundingBox>& cbb,std::vector<float>& cconf);  
  149.   void evaluate();  
  150.   void learn(const cv::Mat& img);  
  151.     
  152.   //Tools  
  153.   void buildGrid(const cv::Mat& img, const cv::Rect& box);  
  154.   float bbOverlap(const BoundingBox& box1,const BoundingBox& box2);  
  155.   void getOverlappingBoxes(const cv::Rect& box1,int num_closest);  
  156.   void getBBHull();  
  157.   void getPattern(const cv::Mat& img, cv::Mat& pattern,cv::Scalar& mean,cv::Scalar& stdev);  
  158.   void bbPoints(std::vector<cv::Point2f>& points, const BoundingBox& bb);  
  159.   void bbPredict(const std::vector<cv::Point2f>& points1,const std::vector<cv::Point2f>& points2,  
  160.       const BoundingBox& bb1,BoundingBox& bb2);  
  161.   double getVar(const BoundingBox& box,const cv::Mat& sum,const cv::Mat& sqsum);  
  162.   bool bbComp(const BoundingBox& bb1,const BoundingBox& bb2);  
  163.   int clusterBB(const std::vector<BoundingBox>& dbb,std::vector<int>& indexes);  
  164. };  


TLD.cpp

[cpp]  view plain copy
  1. /* 
  2.  * TLD.cpp 
  3.  * 
  4.  *  Created on: Jun 9, 2011 
  5.  *      Author: alantrrs 
  6.  */  
  7.   
  8. #include <TLD.h>  
  9. #include <stdio.h>  
  10. using namespace cv;  
  11. using namespace std;  
  12.   
  13.   
  14. TLD::TLD()  
  15. {  
  16. }  
  17. TLD::TLD(const FileNode& file){  
  18.   read(file);  
  19. }  
  20.   
  21. void TLD::read(const FileNode& file){  
  22.   ///Bounding Box Parameters  
  23.   min_win = (int)file["min_win"];  
  24.   ///Genarator Parameters  
  25.   //initial parameters for positive examples  
  26.   patch_size = (int)file["patch_size"];  
  27.   num_closest_init = (int)file["num_closest_init"];  
  28.   num_warps_init = (int)file["num_warps_init"];  
  29.   noise_init = (int)file["noise_init"];  
  30.   angle_init = (float)file["angle_init"];  
  31.   shift_init = (float)file["shift_init"];  
  32.   scale_init = (float)file["scale_init"];  
  33.   //update parameters for positive examples  
  34.   num_closest_update = (int)file["num_closest_update"];  
  35.   num_warps_update = (int)file["num_warps_update"];  
  36.   noise_update = (int)file["noise_update"];  
  37.   angle_update = (float)file["angle_update"];  
  38.   shift_update = (float)file["shift_update"];  
  39.   scale_update = (float)file["scale_update"];  
  40.   //parameters for negative examples  
  41.   bad_overlap = (float)file["overlap"];  
  42.   bad_patches = (int)file["num_patches"];  
  43.   classifier.read(file);  
  44. }  
  45.   
  46. //此函数完成准备工作  
  47. void TLD::init(const Mat& frame1, const Rect& box, FILE* bb_file){  
  48.   //bb_file = fopen("bounding_boxes.txt","w");  
  49.   //Get Bounding Boxes  
  50.   //此函数根据传入的box(目标边界框)在传入的图像frame1中构建全部的扫描窗口,并计算重叠度  
  51.     buildGrid(frame1, box);  
  52.     printf("Created %d bounding boxes\n",(int)grid.size());  //vector的成员size()用于获取向量元素的个数  
  53.       
  54.   ///Preparation  
  55.   //allocation  
  56.   //积分图像,用以计算2bitBP特征(类似于haar特征的计算)  
  57.   //Mat的创建,方式有两种:1.调用create(行,列,类型)2.Mat(行,列,类型(值))。  
  58.   iisum.create(frame1.rows+1, frame1.cols+1, CV_32F);  
  59.   iisqsum.create(frame1.rows+1, frame1.cols+1, CV_64F);  
  60.     
  61.   //Detector data中定义:std::vector<float> dconf;  检测确信度??  
  62.   //vector 的reserve增加了vector的capacity,但是它的size没有改变!而resize改变了vector  
  63.   //的capacity同时也增加了它的size!reserve是容器预留空间,但在空间内不真正创建元素对象,  
  64.   //所以在没有添加新的对象之前,不能引用容器内的元素。  
  65.   //不管是调用resize还是reserve,二者对容器原有的元素都没有影响。  
  66.   //myVec.reserve( 100 );     // 新元素还没有构造, 此时不能用[]访问元素  
  67.   //myVec.resize( 100 );      // 用元素的默认构造函数构造了100个新的元素,可以直接操作新元素  
  68.   dconf.reserve(100);  
  69.   dbb.reserve(100);  
  70.   bbox_step =7;  
  71.     
  72.   //以下在Detector data中定义的容器都给其分配grid.size()大小(这个是一幅图像中全部的扫描窗口个数)的容量  
  73.   //Detector data中定义TempStruct tmp;    
  74.   //tmp.conf.reserve(grid.size());  
  75.   tmp.conf = vector<float>(grid.size());  
  76.   tmp.patt = vector<vector<int> >(grid.size(), vector<int>(10,0));  
  77.   //tmp.patt.reserve(grid.size());  
  78.   dt.bb.reserve(grid.size());  
  79.   good_boxes.reserve(grid.size());  
  80.   bad_boxes.reserve(grid.size());  
  81.     
  82.   //TLD中定义:cv::Mat pEx;  //positive NN example 大小为15*15图像片  
  83.   pEx.create(patch_size, patch_size, CV_64F);  
  84.     
  85.   //Init Generator  
  86.   //TLD中定义:cv::PatchGenerator generator;  //PatchGenerator类用来对图像区域进行仿射变换  
  87.   /* 
  88.   cv::PatchGenerator::PatchGenerator (     
  89.       double     _backgroundMin, 
  90.       double     _backgroundMax, 
  91.       double     _noiseRange, 
  92.       bool     _randomBlur = true, 
  93.       double     _lambdaMin = 0.6, 
  94.       double     _lambdaMax = 1.5, 
  95.       double     _thetaMin = -CV_PI, 
  96.       double     _thetaMax = CV_PI, 
  97.       double     _phiMin = -CV_PI, 
  98.       double     _phiMax = CV_PI  
  99.    )  
  100.    一般的用法是先初始化一个PatchGenerator的实例,然后RNG一个随机因子,再调用()运算符产生一个变换后的正样本。 
  101.   */  
  102.   generator = PatchGenerator (0,0,noise_init,true,1-scale_init,1+scale_init,-angle_init*CV_PI/180,  
  103.                                 angle_init*CV_PI/180,-angle_init*CV_PI/180,angle_init*CV_PI/180);  
  104.     
  105.   //此函数根据传入的box(目标边界框),在整帧图像中的全部窗口中寻找与该box距离最小(即最相似,  
  106.   //重叠度最大)的num_closest_init个窗口,然后把这些窗口 归入good_boxes容器  
  107.   //同时,把重叠度小于0.2的,归入 bad_boxes 容器  
  108.   //首先根据overlap的比例信息选出重复区域比例大于60%并且前num_closet_init= 10个的最接近box的RectBox,  
  109.   //相当于对RectBox进行筛选。并通过BBhull函数得到这些RectBox的最大边界。  
  110.   getOverlappingBoxes(box, num_closest_init);  
  111.   printf("Found %d good boxes, %d bad boxes\n",(int)good_boxes.size(),(int)bad_boxes.size());  
  112.   printf("Best Box: %d %d %d %d\n",best_box.x, best_box.y, best_box.width, best_box.height);  
  113.   printf("Bounding box hull: %d %d %d %d\n", bbhull.x, bbhull.y, bbhull.width, bbhull.height);  
  114.     
  115.   //Correct Bounding Box  
  116.   lastbox=best_box;  
  117.   lastconf=1;  
  118.   lastvalid=true;  
  119.   //Print  
  120.   fprintf(bb_file,"%d,%d,%d,%d,%f\n",lastbox.x,lastbox.y,lastbox.br().x,lastbox.br().y,lastconf);  
  121.     
  122.   //Prepare Classifier 准备分类器  
  123.   //scales容器里是所有扫描窗口的尺度,由buildGrid()函数初始化  
  124.   classifier.prepare(scales);  
  125.     
  126.   ///Generate Data  
  127.   // Generate positive data  
  128.   generatePositiveData(frame1, num_warps_init);  
  129.     
  130.   // Set variance threshold  
  131.   Scalar stdev, mean;  
  132.   //统计best_box的均值和标准差  
  133.   例如需要提取图像A的某个ROI(感兴趣区域,由矩形框)的话,用Mat类的B=img(ROI)即可提取  
  134.   //frame1(best_box)就表示在frame1中提取best_box区域(目标区域)的图像片  
  135.   meanStdDev(frame1(best_box), mean, stdev);  
  136.     
  137.   //利用积分图像去计算每个待检测窗口的方差  
  138.   //cvIntegral( const CvArr* image, CvArr* sum, CvArr* sqsum=NULL, CvArr* tilted_sum=NULL );  
  139.   //计算积分图像,输入图像,sum积分图像, W+1×H+1,sqsum对象素值平方的积分图像,tilted_sum旋转45度的积分图像  
  140.   //利用积分图像,可以计算在某象素的上-右方的或者旋转的矩形区域中进行求和、求均值以及标准方差的计算,  
  141.   //并且保证运算的复杂度为O(1)。    
  142.   integral(frame1, iisum, iisqsum);  
  143.   //级联分类器模块一:方差检测模块,利用积分图计算每个待检测窗口的方差,方差大于var阈值(目标patch方差的50%)的,  
  144.   //则认为其含有前景目标方差;var 为标准差的平方  
  145.   var = pow(stdev.val[0],2) * 0.5; //getVar(best_box,iisum,iisqsum);  
  146.   cout << "variance: " << var << endl;  
  147.     
  148.   //check variance  
  149.   //getVar函数通过积分图像计算输入的best_box的方差  
  150.   double vr =  getVar(best_box, iisum, iisqsum)*0.5;  
  151.   cout << "check variance: " << vr << endl;  
  152.     
  153.   // Generate negative data  
  154.   generateNegativeData(frame1);  
  155.     
  156.   //Split Negative Ferns into Training and Testing sets (they are already shuffled)  
  157.   //将负样本放进 训练和测试集  
  158.   int half = (int)nX.size()*0.5f;  
  159.   //vector::assign函数将区间[start, end)中的值赋值给当前的vector.  
  160.   //将一半的负样本集 作为 测试集  
  161.   nXT.assign(nX.begin()+half, nX.end());  //nXT; //negative data to Test  
  162.   //然后将剩下的一半作为训练集  
  163.   nX.resize(half);  
  164.     
  165.   ///Split Negative NN Examples into Training and Testing sets  
  166.   half = (int)nEx.size()*0.5f;  
  167.   nExT.assign(nEx.begin()+half,nEx.end());  
  168.   nEx.resize(half);  
  169.     
  170.   //Merge Negative Data with Positive Data and shuffle it  
  171.   //将负样本和正样本合并,然后打乱  
  172.   vector<pair<vector<int>,int> > ferns_data(nX.size()+pX.size());  
  173.   vector<int> idx = index_shuffle(0, ferns_data.size());  
  174.   int a=0;  
  175.   for (int i=0;i<pX.size();i++){  
  176.       ferns_data[idx[a]] = pX[i];  
  177.       a++;  
  178.   }  
  179.   for (int i=0;i<nX.size();i++){  
  180.       ferns_data[idx[a]] = nX[i];  
  181.       a++;  
  182.   }  
  183.     
  184.   //Data already have been shuffled, just putting it in the same vector  
  185.   vector<cv::Mat> nn_data(nEx.size()+1);  
  186.   nn_data[0] = pEx;  
  187.   for (int i=0;i<nEx.size();i++){  
  188.       nn_data[i+1]= nEx[i];  
  189.   }  
  190.     
  191.   ///Training    
  192.   //训练 集合分类器(森林) 和 最近邻分类器   
  193.   classifier.trainF(ferns_data, 2); //bootstrap = 2  
  194.   classifier.trainNN(nn_data);  
  195.     
  196.   ///Threshold Evaluation on testing sets  
  197.   //用样本在上面得到的 集合分类器(森林) 和 最近邻分类器 中分类,评价得到最好的阈值  
  198.   classifier.evaluateTh(nXT, nExT);  
  199. }  
  200.   
  201. /* Generate Positive data 
  202.  * Inputs: 
  203.  * - good_boxes (bbP) 
  204.  * - best_box (bbP0) 
  205.  * - frame (im0) 
  206.  * Outputs: 
  207.  * - Positive fern features (pX) 
  208.  * - Positive NN examples (pEx) 
  209.  */  
  210. void TLD::generatePositiveData(const Mat& frame, int num_warps){  
  211.     /* 
  212.     CvScalar定义可存放1—4个数值的数值,常用来存储像素,其结构体如下: 
  213.     typedef struct CvScalar 
  214.     { 
  215.         double val[4]; 
  216.     }CvScalar; 
  217.     如果使用的图像是1通道的,则s.val[0]中存储数据 
  218.     如果使用的图像是3通道的,则s.val[0],s.val[1],s.val[2]中存储数据 
  219.     */  
  220.   Scalar mean;   //均值  
  221.   Scalar stdev;   //标准差  
  222.     
  223.   //此函数将frame图像best_box区域的图像片归一化为均值为0的15*15大小的patch,存在pEx正样本中  
  224.   getPattern(frame(best_box), pEx, mean, stdev);  
  225.     
  226.   //Get Fern features on warped patches  
  227.   Mat img;  
  228.   Mat warped;  
  229.   //void GaussianBlur(InputArray src, OutputArray dst, Size ksize, double sigmaX, double sigmaY=0,   
  230.   //                                    int borderType=BORDER_DEFAULT ) ;  
  231.   //功能:对输入的图像src进行高斯滤波后用dst输出。  
  232.   //src和dst当然分别是输入图像和输出图像。Ksize为高斯滤波器模板大小,sigmaX和sigmaY分别为高斯滤  
  233.   //波在横向和竖向的滤波系数。borderType为边缘扩展点插值类型。  
  234.   //用9*9高斯核模糊输入帧,存入img  去噪??  
  235.   GaussianBlur(frame, img, Size(9,9), 1.5);  
  236.     
  237.   //在img图像中截取bbhull信息(bbhull是包含了位置和大小的矩形框)的图像赋给warped  
  238.   //例如需要提取图像A的某个ROI(感兴趣区域,由矩形框)的话,用Mat类的B=img(ROI)即可提取  
  239.   warped = img(bbhull);  
  240.   RNG& rng = theRNG();  //生成一个随机数  
  241.   Point2f pt(bbhull.x + (bbhull.width-1)*0.5f, bbhull.y+(bbhull.height-1)*0.5f);  //取矩形框中心的坐标  int i(2)  
  242.     
  243.   //nstructs树木(由一个特征组构建,每组特征代表图像块的不同视图表示)的个数  
  244.   //fern[nstructs] nstructs棵树的森林的数组??  
  245.   vector<int> fern(classifier.getNumStructs());  
  246.   pX.clear();  
  247.   Mat patch;  
  248.   
  249.   //pX为处理后的RectBox最大边界处理后的像素信息,pEx最近邻的RectBox的Pattern,bbP0为最近邻的RectBox。  
  250.   if (pX.capacity() < num_warps * good_boxes.size())  
  251.     pX.reserve(num_warps * good_boxes.size());  //pX正样本个数为 仿射变换个数 * good_box的个数,故需分配至少这么大的空间  
  252.   int idx;  
  253.   for (int i=0; i< num_warps; i++){  
  254.      if (i>0)  
  255.      //PatchGenerator类用来对图像区域进行仿射变换,先RNG一个随机因子,再调用()运算符产生一个变换后的正样本。  
  256.        generator(frame, pt, warped, bbhull.size(), rng);  
  257.        for (int b=0; b < good_boxes.size(); b++){  
  258.          idx = good_boxes[b];  //good_boxes容器保存的是 grid 的索引  
  259.          patch = img(grid[idx]);  //把img的 grid[idx] 区域(也就是bounding box重叠度高的)这一块图像片提取出来  
  260.          //getFeatures函数得到输入的patch的用于树的节点,也就是特征组的特征fern(13位的二进制代码)  
  261.          classifier.getFeatures(patch, grid[idx].sidx, fern);  //grid[idx].sidx 对应的尺度索引  
  262.          pX.push_back(make_pair(fern, 1));   //positive ferns <features, labels=1>  正样本  
  263.      }  
  264.   }  
  265.   printf("Positive examples generated: ferns:%d NN:1\n",(int)pX.size());  
  266. }  
  267.   
  268. //先对最接近box的RectBox区域得到其patch ,然后将像素信息转换为Pattern,  
  269. //具体的说就是归一化RectBox对应的patch的size(放缩至patch_size = 15*15),将2维的矩阵变成一维的向量信息,  
  270. //然后将向量信息均值设为0,调整为zero mean and unit variance(ZMUV)  
  271. //Output: resized Zero-Mean patch  
  272. void TLD::getPattern(const Mat& img, Mat& pattern, Scalar& mean, Scalar& stdev){  
  273.   //将img放缩至patch_size = 15*15,存到pattern中  
  274.   resize(img, pattern, Size(patch_size, patch_size));  
  275.     
  276.   //计算pattern这个矩阵的均值和标准差  
  277.   //Computes a mean value and a standard deviation of matrix elements.  
  278.   meanStdDev(pattern, mean, stdev);  
  279.   pattern.convertTo(pattern, CV_32F);  
  280.     
  281.   //opencv中Mat的运算符有重载, Mat可以 + Mat; + Scalar; + int / float / double 都可以  
  282.   //将矩阵所有元素减去其均值,也就是把patch的均值设为零  
  283.   pattern = pattern - mean.val[0];  
  284. }  
  285.   
  286. /* Inputs: 
  287.  * - Image 
  288.  * - bad_boxes (Boxes far from the bounding box) 
  289.  * - variance (pEx variance) 
  290.  * Outputs 
  291.  * - Negative fern features (nX) 
  292.  * - Negative NN examples (nEx) 
  293.  */  
  294. void TLD::generateNegativeData(const Mat& frame){  
  295.   //由于之前重叠度小于0.2的,都归入 bad_boxes了,所以数量挺多,下面的函数用于打乱顺序,也就是为了  
  296.   //后面随机选择bad_boxes  
  297.   random_shuffle(bad_boxes.begin(), bad_boxes.end());//Random shuffle bad_boxes indexes  
  298.   int idx;  
  299.   //Get Fern Features of the boxes with big variance (calculated using integral images)  
  300.   int a=0;  
  301.   //int num = std::min((int)bad_boxes.size(),(int)bad_patches*100); //limits the size of bad_boxes to try  
  302.   printf("negative data generation started.\n");  
  303.   vector<int> fern(classifier.getNumStructs());  
  304.   nX.reserve(bad_boxes.size());  
  305.   Mat patch;  
  306.   for (int j=0;j<bad_boxes.size();j++){  //把方差较大的bad_boxes加入负样本  
  307.       idx = bad_boxes[j];  
  308.           if (getVar(grid[idx],iisum,iisqsum)<var*0.5f)  
  309.             continue;  
  310.       patch =  frame(grid[idx]);  
  311.       classifier.getFeatures(patch, grid[idx].sidx, fern);  
  312.       nX.push_back(make_pair(fern, 0)); //得到负样本  
  313.       a++;  
  314.   }  
  315.   printf("Negative examples generated: ferns: %d ", a);  
  316.     
  317.   //random_shuffle(bad_boxes.begin(),bad_boxes.begin()+bad_patches);//Randomly selects 'bad_patches' and get the patterns for NN;  
  318.   Scalar dum1, dum2;  
  319.   //bad_patches = (int)file["num_patches"]; 在参数文件中 num_patches = 100  
  320.   nEx=vector<Mat>(bad_patches);  
  321.   for (int i=0;i<bad_patches;i++){  
  322.       idx=bad_boxes[i];  
  323.       patch = frame(grid[idx]);  
  324.       //具体的说就是归一化RectBox对应的patch的size(放缩至patch_size = 15*15)  
  325.       //由于负样本不需要均值和方差,所以就定义dum,将其舍弃  
  326.       getPattern(patch,nEx[i],dum1,dum2);  
  327.   }  
  328.   printf("NN: %d\n",(int)nEx.size());  
  329. }  
  330.   
  331. //该函数通过积分图像计算输入的box的方差  
  332. double TLD::getVar(const BoundingBox& box, const Mat& sum, const Mat& sqsum){  
  333.   double brs = sum.at<int>(box.y+box.height, box.x+box.width);  
  334.   double bls = sum.at<int>(box.y+box.height, box.x);  
  335.   double trs = sum.at<int>(box.y,box.x + box.width);  
  336.   double tls = sum.at<int>(box.y,box.x);  
  337.   double brsq = sqsum.at<double>(box.y+box.height,box.x+box.width);  
  338.   double blsq = sqsum.at<double>(box.y+box.height,box.x);  
  339.   double trsq = sqsum.at<double>(box.y,box.x+box.width);  
  340.   double tlsq = sqsum.at<double>(box.y,box.x);  
  341.     
  342.   double mean = (brs+tls-trs-bls)/((double)box.area());  
  343.   double sqmean = (brsq+tlsq-trsq-blsq)/((double)box.area());  
  344.   //方差=E(X^2)-(EX)^2   EX表示均值  
  345.   return sqmean-mean*mean;  
  346. }  
  347.   
  348. void TLD::processFrame(const cv::Mat& img1,const cv::Mat& img2,vector<Point2f>& points1,vector<Point2f>& points2,BoundingBox& bbnext, bool& lastboxfound, bool tl, FILE* bb_file){  
  349.   vector<BoundingBox> cbb;  
  350.   vector<float> cconf;  
  351.   int confident_detections=0;  
  352.   int didx; //detection index  
  353.     
  354.   ///Track  跟踪模块  
  355.   if(lastboxfound && tl){   //tl: train and learn  
  356.       //跟踪  
  357.       track(img1, img2, points1, points2);  
  358.   }  
  359.   else{  
  360.       tracked = false;  
  361.   }  
  362.     
  363.   ///Detect   检测模块  
  364.   detect(img2);  
  365.     
  366.   ///Integration   综合模块  
  367.   //TLD只跟踪单目标,所以综合模块综合跟踪器跟踪到的单个目标和检测器检测到的多个目标,然后只输出保守相似度最大的一个目标  
  368.   if (tracked){  
  369.       bbnext=tbb;  
  370.       lastconf=tconf;   //表示相关相似度的阈值  
  371.       lastvalid=tvalid;  //表示保守相似度的阈值  
  372.       printf("Tracked\n");  
  373.       if(detected){                                               //   if Detected  
  374.           //通过 重叠度 对检测器检测到的目标bounding box进行聚类,每个类其重叠度小于0.5  
  375.           clusterConf(dbb, dconf, cbb, cconf);                       //   cluster detections  
  376.           printf("Found %d clusters\n",(int)cbb.size());  
  377.           for (int i=0;i<cbb.size();i++){  
  378.               //找到与跟踪器跟踪到的box距离比较远的类(检测器检测到的box),而且它的相关相似度比跟踪器的要大  
  379.               if (bbOverlap(tbb, cbb[i])<0.5 && cconf[i]>tconf){  //  Get index of a clusters that is far from tracker and are more confident than the tracker  
  380.                   confident_detections++;  //记录满足上述条件,也就是可信度比较高的目标box的个数  
  381.                   didx=i; //detection index  
  382.               }  
  383.           }  
  384.           //如果只有一个满足上述条件的box,那么就用这个目标box来重新初始化跟踪器(也就是用检测器的结果去纠正跟踪器)  
  385.           if (confident_detections==1){                                //if there is ONE such a cluster, re-initialize the tracker  
  386.               printf("Found a better match..reinitializing tracking\n");  
  387.               bbnext=cbb[didx];  
  388.               lastconf=cconf[didx];  
  389.               lastvalid=false;  
  390.           }  
  391.           else {  
  392.               printf("%d confident cluster was found\n", confident_detections);  
  393.               int cx=0,cy=0,cw=0,ch=0;  
  394.               int close_detections=0;  
  395.               for (int i=0;i<dbb.size();i++){  
  396.                   //找到检测器检测到的box与跟踪器预测到的box距离很近(重叠度大于0.7)的box,对其坐标和大小进行累加  
  397.                   if(bbOverlap(tbb,dbb[i])>0.7){                     // Get mean of close detections  
  398.                       cx += dbb[i].x;  
  399.                       cy +=dbb[i].y;  
  400.                       cw += dbb[i].width;  
  401.                       ch += dbb[i].height;  
  402.                       close_detections++;   //记录最近邻box的个数  
  403.                       printf("weighted detection: %d %d %d %d\n",dbb[i].x,dbb[i].y,dbb[i].width,dbb[i].height);  
  404.                   }  
  405.               }  
  406.               if (close_detections>0){  
  407.                   //对与跟踪器预测到的box距离很近的box 和 跟踪器本身预测到的box 进行坐标与大小的平均作为最终的  
  408.                   //目标bounding box,但是跟踪器的权值较大  
  409.                   bbnext.x = cvRound((float)(10*tbb.x+cx)/(float)(10+close_detections));   // weighted average trackers trajectory with the close detections  
  410.                   bbnext.y = cvRound((float)(10*tbb.y+cy)/(float)(10+close_detections));  
  411.                   bbnext.width = cvRound((float)(10*tbb.width+cw)/(float)(10+close_detections));  
  412.                   bbnext.height =  cvRound((float)(10*tbb.height+ch)/(float)(10+close_detections));  
  413.                   printf("Tracker bb: %d %d %d %d\n",tbb.x,tbb.y,tbb.width,tbb.height);  
  414.                   printf("Average bb: %d %d %d %d\n",bbnext.x,bbnext.y,bbnext.width,bbnext.height);  
  415.                   printf("Weighting %d close detection(s) with tracker..\n",close_detections);  
  416.               }  
  417.               else{  
  418.                 printf("%d close detections were found\n",close_detections);  
  419.   
  420.               }  
  421.           }  
  422.       }  
  423.   }  
  424.   else{                                       //   If NOT tracking  
  425.       printf("Not tracking..\n");  
  426.       lastboxfound = false;  
  427.       lastvalid = false;  
  428.       //如果跟踪器没有跟踪到目标,但是检测器检测到了一些可能的目标box,那么同样对其进行聚类,但只是简单的  
  429.       //将聚类的cbb[0]作为新的跟踪目标box(不比较相似度了??还是里面已经排好序了??),重新初始化跟踪器  
  430.       if(detected){                           //  and detector is defined  
  431.           clusterConf(dbb,dconf,cbb,cconf);   //  cluster detections  
  432.           printf("Found %d clusters\n",(int)cbb.size());  
  433.           if (cconf.size()==1){  
  434.               bbnext=cbb[0];  
  435.               lastconf=cconf[0];  
  436.               printf("Confident detection..reinitializing tracker\n");  
  437.               lastboxfound = true;  
  438.           }  
  439.       }  
  440.   }  
  441.   lastbox=bbnext;  
  442.   if (lastboxfound)  
  443.     fprintf(bb_file,"%d,%d,%d,%d,%f\n",lastbox.x,lastbox.y,lastbox.br().x,lastbox.br().y,lastconf);  
  444.   else  
  445.     fprintf(bb_file,"NaN,NaN,NaN,NaN,NaN\n");  
  446.       
  447.   ///learn 学习模块  
  448.   if (lastvalid && tl)  
  449.     learn(img2);  
  450. }  
  451.   
  452. /*Inputs: 
  453. * -current frame(img2), last frame(img1), last Bbox(bbox_f[0]). 
  454. *Outputs: 
  455. *- Confidence(tconf), Predicted bounding box(tbb), Validity(tvalid), points2 (for display purposes only) 
  456. */  
  457. void TLD::track(const Mat& img1, const Mat& img2, vector<Point2f>& points1, vector<Point2f>& points2){  
  458.     
  459.   //Generate points  
  460.   //网格均匀撒点(均匀采样),在lastbox中共产生最多10*10=100个特征点,存于points1  
  461.   bbPoints(points1, lastbox);  
  462.   if (points1.size()<1){  
  463.       printf("BB= %d %d %d %d, Points not generated\n",lastbox.x,lastbox.y,lastbox.width,lastbox.height);  
  464.       tvalid=false;  
  465.       tracked=false;  
  466.       return;  
  467.   }  
  468.   vector<Point2f> points = points1;  
  469.     
  470.   //Frame-to-frame tracking with forward-backward error cheking  
  471.   //trackf2f函数完成:跟踪、计算FB error和匹配相似度sim,然后筛选出 FB_error[i] <= median(FB_error) 和   
  472.   //sim_error[i] > median(sim_error) 的特征点(跟踪结果不好的特征点),剩下的是不到50%的特征点  
  473.   tracked = tracker.trackf2f(img1, img2, points, points2);  
  474.   if (tracked){  
  475.       //Bounding box prediction  
  476.       //利用剩下的这不到一半的跟踪点输入来预测bounding box在当前帧的位置和大小 tbb  
  477.       bbPredict(points, points2, lastbox, tbb);  
  478.       //跟踪失败检测:如果FB error的中值大于10个像素(经验值),或者预测到的当前box的位置移出图像,则  
  479.       //认为跟踪错误,此时不返回bounding box;Rect::br()返回的是右下角的坐标  
  480.       //getFB()返回的是FB error的中值  
  481.       if (tracker.getFB()>10 || tbb.x>img2.cols ||  tbb.y>img2.rows || tbb.br().x < 1 || tbb.br().y <1){  
  482.           tvalid =false//too unstable prediction or bounding box out of image  
  483.           tracked = false;  
  484.           printf("Too unstable predictions FB error=%f\n", tracker.getFB());  
  485.           return;  
  486.       }  
  487.         
  488.       //Estimate Confidence and Validity  
  489.       //评估跟踪确信度和有效性  
  490.       Mat pattern;  
  491.       Scalar mean, stdev;  
  492.       BoundingBox bb;  
  493.       bb.x = max(tbb.x,0);  
  494.       bb.y = max(tbb.y,0);  
  495.       bb.width = min(min(img2.cols-tbb.x,tbb.width), min(tbb.width, tbb.br().x));  
  496.       bb.height = min(min(img2.rows-tbb.y,tbb.height),min(tbb.height,tbb.br().y));  
  497.       //归一化img2(bb)对应的patch的size(放缩至patch_size = 15*15),存入pattern  
  498.       getPattern(img2(bb),pattern,mean,stdev);  
  499.       vector<int> isin;  
  500.       float dummy;  
  501.       //计算图像片pattern到在线模型M的保守相似度  
  502.       classifier.NNConf(pattern,isin,dummy,tconf); //Conservative Similarity  
  503.       tvalid = lastvalid;  
  504.       //保守相似度大于阈值,则评估跟踪有效  
  505.       if (tconf>classifier.thr_nn_valid){  
  506.           tvalid =true;  
  507.       }  
  508.   }  
  509.   else  
  510.     printf("No points tracked\n");  
  511.   
  512. }  
  513.   
  514. //网格均匀撒点,box共10*10=100个特征点  
  515. void TLD::bbPoints(vector<cv::Point2f>& points, const BoundingBox& bb){  
  516.   int max_pts=10;  
  517.   int margin_h=0; //采样边界  
  518.   int margin_v=0;  
  519.   //网格均匀撒点  
  520.   int stepx = ceil((bb.width-2*margin_h)/max_pts);  //ceil返回大于或者等于指定表达式的最小整数  
  521.   int stepy = ceil((bb.height-2*margin_v)/max_pts);  
  522.   //网格均匀撒点,box共10*10=100个特征点  
  523.   for (int y=bb.y+margin_v; y<bb.y+bb.height-margin_v; y+=stepy){  
  524.       for (int x=bb.x+margin_h;x<bb.x+bb.width-margin_h;x+=stepx){  
  525.           points.push_back(Point2f(x,y));  
  526.       }  
  527.   }  
  528. }  
  529.   
  530. //利用剩下的这不到一半的跟踪点输入来预测bounding box在当前帧的位置和大小  
  531. void TLD::bbPredict(const vector<cv::Point2f>& points1,const vector<cv::Point2f>& points2,  
  532.                     const BoundingBox& bb1,BoundingBox& bb2)    {  
  533.   int npoints = (int)points1.size();  
  534.   vector<float> xoff(npoints);  //位移  
  535.   vector<float> yoff(npoints);  
  536.   printf("tracked points : %d\n", npoints);  
  537.   for (int i=0;i<npoints;i++){   //计算每个特征点在两帧之间的位移  
  538.       xoff[i]=points2[i].x - points1[i].x;  
  539.       yoff[i]=points2[i].y - points1[i].y;  
  540.   }  
  541.   float dx = median(xoff);   //计算位移的中值  
  542.   float dy = median(yoff);  
  543.   float s;  
  544.   //计算bounding box尺度scale的变化:通过计算 当前特征点相互间的距离 与 先前(上一帧)特征点相互间的距离 的  
  545.   //比值,以比值的中值作为尺度的变化因子  
  546.   if (npoints>1){  
  547.       vector<float> d;  
  548.       d.reserve(npoints*(npoints-1)/2);  //等差数列求和:1+2+...+(npoints-1)  
  549.       for (int i=0;i<npoints;i++){  
  550.           for (int j=i+1;j<npoints;j++){  
  551.           //计算 当前特征点相互间的距离 与 先前(上一帧)特征点相互间的距离 的比值(位移用绝对值)  
  552.               d.push_back(norm(points2[i]-points2[j])/norm(points1[i]-points1[j]));  
  553.           }  
  554.       }  
  555.       s = median(d);  
  556.   }  
  557.   else {  
  558.       s = 1.0;  
  559.   }  
  560.   
  561.   float s1 = 0.5*(s-1)*bb1.width;  
  562.   float s2 = 0.5*(s-1)*bb1.height;  
  563.   printf("s= %f s1= %f s2= %f \n", s, s1, s2);  
  564.     
  565.   //得到当前bounding box的位置与大小信息  
  566.   //当前box的x坐标 = 前一帧box的x坐标 + 全部特征点位移的中值(可理解为box移动近似的位移) - 当前box宽的一半  
  567.   bb2.x = round( bb1.x + dx - s1);  
  568.   bb2.y = round( bb1.y + dy -s2);  
  569.   bb2.width = round(bb1.width*s);  
  570.   bb2.height = round(bb1.height*s);  
  571.   printf("predicted bb: %d %d %d %d\n",bb2.x,bb2.y,bb2.br().x,bb2.br().y);  
  572. }  
  573.   
  574. void TLD::detect(const cv::Mat& frame){  
  575.   //cleaning  
  576.   dbb.clear();  
  577.   dconf.clear();  
  578.   dt.bb.clear();  
  579.   //GetTickCount返回从操作系统启动到现在所经过的时间  
  580.   double t = (double)getTickCount();  
  581.   Mat img(frame.rows, frame.cols, CV_8U);  
  582.   integral(frame,iisum,iisqsum);   //计算frame的积分图   
  583.   GaussianBlur(frame,img,Size(9,9),1.5);  //高斯模糊,去噪?  
  584.   int numtrees = classifier.getNumStructs();  
  585.   float fern_th = classifier.getFernTh(); //getFernTh()返回thr_fern; 集合分类器的分类阈值  
  586.   vector <int> ferns(10);  
  587.   float conf;  
  588.   int a=0;  
  589.   Mat patch;  
  590.   //级联分类器模块一:方差检测模块,利用积分图计算每个待检测窗口的方差,方差大于var阈值(目标patch方差的50%)的,  
  591.   //则认为其含有前景目标  
  592.   for (int i=0; i<grid.size(); i++){  //FIXME: BottleNeck 瓶颈  
  593.       if (getVar(grid[i],iisum,iisqsum) >= var){  //计算每一个扫描窗口的方差  
  594.           a++;  
  595.           //级联分类器模块二:集合分类器检测模块  
  596.           patch = img(grid[i]);  
  597.           classifier.getFeatures(patch,grid[i].sidx,ferns); //得到该patch特征(13位的二进制代码)  
  598.           conf = classifier.measure_forest(ferns);  //计算该特征值对应的后验概率累加值  
  599.           tmp.conf[i]=conf;   //Detector data中定义TempStruct tmp;   
  600.           tmp.patt[i]=ferns;  
  601.           //如果集合分类器的后验概率的平均值大于阈值fern_th(由训练得到),就认为含有前景目标  
  602.           if (conf > numtrees*fern_th){    
  603.               dt.bb.push_back(i);  //将通过以上两个检测模块的扫描窗口记录在detect structure中  
  604.           }  
  605.       }  
  606.       else  
  607.         tmp.conf[i]=0.0;  
  608.   }  
  609.   int detections = dt.bb.size();  
  610.   printf("%d Bounding boxes passed the variance filter\n",a);  
  611.   printf("%d Initial detection from Fern Classifier\n", detections);  
  612.     
  613.   //如果通过以上两个检测模块的扫描窗口数大于100个,则只取后验概率大的前100个  
  614.   if (detections>100){   //CComparator(tmp.conf)指定比较方式???  
  615.       nth_element(dt.bb.begin(), dt.bb.begin()+100, dt.bb.end(), CComparator(tmp.conf));  
  616.       dt.bb.resize(100);  
  617.       detections=100;  
  618.   }  
  619. //  for (int i=0;i<detections;i++){  
  620. //        drawBox(img,grid[dt.bb[i]]);  
  621. //    }  
  622. //  imshow("detections",img);  
  623.   if (detections==0){  
  624.         detected=false;  
  625.         return;  
  626.       }  
  627.   printf("Fern detector made %d detections ",detections);  
  628.     
  629.   //两次使用getTickCount(),然后再除以getTickFrequency(),计算出来的是以秒s为单位的时间(opencv 2.0 以前是ms)  
  630.   t=(double)getTickCount()-t;    
  631.   printf("in %gms\n", t*1000/getTickFrequency());  //打印以上代码运行使用的毫秒数  
  632.     
  633.   //  Initialize detection structure  
  634.   dt.patt = vector<vector<int> >(detections,vector<int>(10,0));        //  Corresponding codes of the Ensemble Classifier  
  635.   dt.conf1 = vector<float>(detections);                                //  Relative Similarity (for final nearest neighbour classifier)  
  636.   dt.conf2 =vector<float>(detections);                                 //  Conservative Similarity (for integration with tracker)  
  637.   dt.isin = vector<vector<int> >(detections,vector<int>(3,-1));        //  Detected (isin=1) or rejected (isin=0) by nearest neighbour classifier  
  638.   dt.patch = vector<Mat>(detections,Mat(patch_size,patch_size,CV_32F));//  Corresponding patches  
  639.   int idx;  
  640.   Scalar mean, stdev;  
  641.   float nn_th = classifier.getNNTh();  
  642.   //级联分类器模块三:最近邻分类器检测模块  
  643.   for (int i=0;i<detections;i++){                                         //  for every remaining detection  
  644.       idx=dt.bb[i];                                                       //  Get the detected bounding box index  
  645.       patch = frame(grid[idx]);  
  646.       getPattern(patch,dt.patch[i],mean,stdev);                //  Get pattern within bounding box  
  647.       //计算图像片pattern到在线模型M的相关相似度和保守相似度  
  648.       classifier.NNConf(dt.patch[i],dt.isin[i],dt.conf1[i],dt.conf2[i]);  //  Evaluate nearest neighbour classifier  
  649.       dt.patt[i]=tmp.patt[idx];  
  650.       //printf("Testing feature %d, conf:%f isin:(%d|%d|%d)\n",i,dt.conf1[i],dt.isin[i][0],dt.isin[i][1],dt.isin[i][2]);  
  651.       //相关相似度大于阈值,则认为含有前景目标  
  652.       if (dt.conf1[i]>nn_th){                                               //  idx = dt.conf1 > tld.model.thr_nn; % get all indexes that made it through the nearest neighbour  
  653.           dbb.push_back(grid[idx]);                                         //  BB    = dt.bb(:,idx); % bounding boxes  
  654.           dconf.push_back(dt.conf2[i]);                                     //  Conf  = dt.conf2(:,idx); % conservative confidences  
  655.       }  
  656.   }  
  657.   //打印检测到的可能存在目标的扫描窗口数(可以通过三个级联检测器的)  
  658.   if (dbb.size()>0){  
  659.       printf("Found %d NN matches\n",(int)dbb.size());  
  660.       detected=true;  
  661.   }  
  662.   else{  
  663.       printf("No NN matches found.\n");  
  664.       detected=false;  
  665.   }  
  666. }  
  667.   
  668. //作者已经用python脚本../datasets/evaluate_vis.py来完成算法评估功能,具体见README  
  669. void TLD::evaluate(){  
  670. }  
  671.   
  672. void TLD::learn(const Mat& img){  
  673.   printf("[Learning] ");  
  674.     
  675.   ///Check consistency  
  676.   //检测一致性  
  677.   BoundingBox bb;  
  678.   bb.x = max(lastbox.x,0);  
  679.   bb.y = max(lastbox.y,0);  
  680.   bb.width = min(min(img.cols-lastbox.x,lastbox.width),min(lastbox.width,lastbox.br().x));  
  681.   bb.height = min(min(img.rows-lastbox.y,lastbox.height),min(lastbox.height,lastbox.br().y));  
  682.   Scalar mean, stdev;  
  683.   Mat pattern;  
  684.   //归一化img(bb)对应的patch的size(放缩至patch_size = 15*15),存入pattern  
  685.   getPattern(img(bb), pattern, mean, stdev);  
  686.   vector<int> isin;  
  687.   float dummy, conf;  
  688.   //计算输入图像片(跟踪器的目标box)与在线模型之间的相关相似度conf  
  689.   classifier.NNConf(pattern,isin,conf,dummy);  
  690.   if (conf<0.5) {   //如果相似度太小了,就不训练  
  691.       printf("Fast change..not training\n");  
  692.       lastvalid =false;  
  693.       return;  
  694.   }  
  695.   if (pow(stdev.val[0], 2)< var){  //如果方差太小了,也不训练  
  696.       printf("Low variance..not training\n");  
  697.       lastvalid=false;  
  698.       return;  
  699.   }  
  700.   if(isin[2]==1){   //如果被被识别为负样本,也不训练  
  701.       printf("Patch in negative data..not traing");  
  702.       lastvalid=false;  
  703.       return;  
  704.   }  
  705.     
  706.   /// Data generation  样本产生  
  707.   for (int i=0;i<grid.size();i++){   //计算所有的扫描窗口与目标box的重叠度  
  708.       grid[i].overlap = bbOverlap(lastbox, grid[i]);  
  709.   }  
  710.   //集合分类器  
  711.   vector<pair<vector<int>,int> > fern_examples;  
  712.   good_boxes.clear();    
  713.   bad_boxes.clear();  
  714.   //此函数根据传入的lastbox,在整帧图像中的全部窗口中寻找与该lastbox距离最小(即最相似,  
  715.   //重叠度最大)的num_closest_update个窗口,然后把这些窗口 归入good_boxes容器(只是把网格数组的索引存入)  
  716.   //同时,把重叠度小于0.2的,归入 bad_boxes 容器  
  717.   getOverlappingBoxes(lastbox, num_closest_update);  
  718.   if (good_boxes.size()>0)  
  719.     generatePositiveData(img, num_warps_update);  //用仿射模型产生正样本(类似于第一帧的方法,但只产生10*10=100个)  
  720.   else{  
  721.     lastvalid = false;  
  722.     printf("No good boxes..Not training");  
  723.     return;  
  724.   }  
  725.   fern_examples.reserve(pX.size() + bad_boxes.size());  
  726.   fern_examples.assign(pX.begin(), pX.end());  
  727.   int idx;  
  728.   for (int i=0;i<bad_boxes.size();i++){  
  729.       idx=bad_boxes[i];  
  730.       if (tmp.conf[idx]>=1){   //加入负样本,相似度大于1??相似度不是出于0和1之间吗?  
  731.           fern_examples.push_back(make_pair(tmp.patt[idx],0));  
  732.       }  
  733.   }  
  734.   //最近邻分类器  
  735.   vector<Mat> nn_examples;  
  736.   nn_examples.reserve(dt.bb.size()+1);  
  737.   nn_examples.push_back(pEx);  
  738.   for (int i=0;i<dt.bb.size();i++){  
  739.       idx = dt.bb[i];  
  740.       if (bbOverlap(lastbox,grid[idx]) < bad_overlap)  
  741.         nn_examples.push_back(dt.patch[i]);  
  742.   }  
  743.     
  744.   /// Classifiers update  分类器训练  
  745.   classifier.trainF(fern_examples,2);  
  746.   classifier.trainNN(nn_examples);  
  747.   classifier.show(); //把正样本库(在线模型)包含的所有正样本显示在窗口上  
  748. }  
  749.   
  750. //检测器采用扫描窗口的策略  
  751. //此函数根据传入的box(目标边界框)在传入的图像中构建全部的扫描窗口,并计算每个窗口与box的重叠度  
  752. void TLD::buildGrid(const cv::Mat& img, const cv::Rect& box){  
  753.   const float SHIFT = 0.1;  //扫描窗口步长为 宽高的 10%  
  754.   //尺度缩放系数为1.2 (0.16151*1.2=0.19381),共21种尺度变换  
  755.   const float SCALES[] = {0.16151,0.19381,0.23257,0.27908,0.33490,0.40188,0.48225,  
  756.                           0.57870,0.69444,0.83333,1,1.20000,1.44000,1.72800,  
  757.                           2.07360,2.48832,2.98598,3.58318,4.29982,5.15978,6.19174};  
  758.   int width, height, min_bb_side;  
  759.   //Rect bbox;  
  760.   BoundingBox bbox;  
  761.   Size scale;  
  762.   int sc=0;  
  763.     
  764.   for (int s=0; s < 21; s++){  
  765.     width = round(box.width*SCALES[s]);  
  766.     height = round(box.height*SCALES[s]);  
  767.     min_bb_side = min(height,width);  //bounding box最短的边  
  768.     //由于图像片(min_win 为15x15像素)是在bounding box中采样得到的,所以box必须比min_win要大  
  769.     //另外,输入的图像肯定得比 bounding box 要大了  
  770.     if (min_bb_side < min_win || width > img.cols || height > img.rows)  
  771.       continue;  
  772.     scale.width = width;  
  773.     scale.height = height;  
  774.     //push_back在vector类中作用为在vector尾部加入一个数据  
  775.     //scales在类TLD中定义:std::vector<cv::Size> scales;  
  776.     scales.push_back(scale);  //把该尺度的窗口存入scales容器,避免在扫描时计算,加快检测速度  
  777.     for (int y=1; y<img.rows-height; y+=round(SHIFT*min_bb_side)){  //按步长移动窗口  
  778.       for (int x=1; x<img.cols-width; x+=round(SHIFT*min_bb_side)){  
  779.         bbox.x = x;  
  780.         bbox.y = y;  
  781.         bbox.width = width;  
  782.         bbox.height = height;  
  783.         //判断传入的bounding box(目标边界框)与 传入图像中的此时窗口的 重叠度,  
  784.         //以此来确定该图像窗口是否含有目标  
  785.         bbox.overlap = bbOverlap(bbox, BoundingBox(box));  
  786.         bbox.sidx = sc;  //属于第几个尺度  
  787.         //grid在类TLD中定义:std::vector<BoundingBox> grid;  
  788.         //把本位置和本尺度的扫描窗口存入grid容器  
  789.         grid.push_back(bbox);  
  790.       }  
  791.     }  
  792.     sc++;  
  793.   }  
  794. }  
  795.   
  796. //此函数计算两个bounding box 的重叠度  
  797. //重叠度定义为 两个box的交集 与 它们的并集 的比  
  798. float TLD::bbOverlap(const BoundingBox& box1, const BoundingBox& box2){  
  799.   //先判断坐标,假如它们都没有重叠的地方,就直接返回0  
  800.   if (box1.x > box2.x + box2.width) { return 0.0; }  
  801.   if (box1.y > box2.y + box2.height) { return 0.0; }  
  802.   if (box1.x + box1.width < box2.x) { return 0.0; }  
  803.   if (box1.y + box1.height < box2.y) { return 0.0; }  
  804.   
  805.   float colInt =  min(box1.x + box1.width, box2.x + box2.width) - max(box1.x, box2.x);  
  806.   float rowInt =  min(box1.y + box1.height, box2.y + box2.height) - max(box1.y, box2.y);  
  807.   
  808.   float intersection = colInt * rowInt;  
  809.   float area1 = box1.width * box1.height;  
  810.   float area2 = box2.width * box2.height;  
  811.   return intersection / (area1 + area2 - intersection);  
  812. }  
  813.   
  814. //此函数根据传入的box1(目标边界框),在整帧图像中的全部窗口中寻找与该box1距离最小(即最相似,  
  815. //重叠度最大)的num_closest个窗口,然后把这些窗口 归入good_boxes容器(只是把网格数组的索引存入)  
  816. //同时,把重叠度小于0.2的,归入 bad_boxes 容器  
  817. void TLD::getOverlappingBoxes(const cv::Rect& box1,int num_closest){  
  818.   float max_overlap = 0;  
  819.   for (int i=0;i<grid.size();i++){  
  820.       if (grid[i].overlap > max_overlap) {  //找出重叠度最大的box  
  821.           max_overlap = grid[i].overlap;  
  822.           best_box = grid[i];         
  823.       }  
  824.       if (grid[i].overlap > 0.6){   //重叠度大于0.6的,归入 good_boxes  
  825.           good_boxes.push_back(i);  
  826.       }  
  827.       else if (grid[i].overlap < bad_overlap){  //重叠度小于0.2的,归入 bad_boxes  
  828.           bad_boxes.push_back(i);  
  829.       }  
  830.   }  
  831.   //Get the best num_closest (10) boxes and puts them in good_boxes  
  832.   if (good_boxes.size()>num_closest){  
  833.   //STL中的nth_element()方法找出一个数列中排名第n(下面为第num_closest)的那个数。这个函数运行后  
  834.   //在good_boxes[num_closest]前面num_closest个数都比他大,也就是找到最好的num_closest个box了  
  835.     std::nth_element(good_boxes.begin(), good_boxes.begin() + num_closest, good_boxes.end(), OComparator(grid));  
  836.     //重新压缩good_boxes为num_closest大小  
  837.     good_boxes.resize(num_closest);  
  838.   }  
  839.   //获取good_boxes 的 Hull壳,也就是窗口的边框  
  840.   getBBHull();  
  841. }  
  842.   
  843. //此函数获取good_boxes 的 Hull壳,也就是窗口(图像)的边框 bounding box  
  844. void TLD::getBBHull(){  
  845.   int x1=INT_MAX, x2=0;  //INT_MAX 最大的整形数  
  846.   int y1=INT_MAX, y2=0;  
  847.   int idx;  
  848.   for (int i=0;i<good_boxes.size();i++){  
  849.       idx= good_boxes[i];  
  850.       x1=min(grid[idx].x,x1);   //防止出现负数??  
  851.       y1=min(grid[idx].y,y1);  
  852.       x2=max(grid[idx].x + grid[idx].width,x2);  
  853.       y2=max(grid[idx].y + grid[idx].height,y2);  
  854.   }  
  855.   bbhull.x = x1;  
  856.   bbhull.y = y1;  
  857.   bbhull.width = x2-x1;  
  858.   bbhull.height = y2 -y1;  
  859. }  
  860.   
  861. //如果两个box的重叠度小于0.5,返回false,否则返回true  
  862. bool bbcomp(const BoundingBox& b1,const BoundingBox& b2){  
  863.   TLD t;  
  864.     if (t.bbOverlap(b1,b2)<0.5)  
  865.       return false;  
  866.     else  
  867.       return true;  
  868. }  
  869.   
  870. int TLD::clusterBB(const vector<BoundingBox>& dbb,vector<int>& indexes){  
  871.   //FIXME: Conditional jump or move depends on uninitialised value(s)  
  872.   const int c = dbb.size();  
  873.   //1. Build proximity matrix  
  874.   Mat D(c,c,CV_32F);  
  875.   float d;  
  876.   for (int i=0;i<c;i++){  
  877.       for (int j=i+1;j<c;j++){  
  878.         d = 1-bbOverlap(dbb[i],dbb[j]);  
  879.         D.at<float>(i,j) = d;  
  880.         D.at<float>(j,i) = d;  
  881.       }  
  882.   }  
  883.   //2. Initialize disjoint clustering  
  884.  float L[c-1]; //Level  
  885.  int nodes[c-1][2];  
  886.  int belongs[c];  
  887.  int m=c;  
  888.  for (int i=0;i<c;i++){  
  889.     belongs[i]=i;  
  890.  }  
  891.  for (int it=0;it<c-1;it++){  
  892.  //3. Find nearest neighbor  
  893.      float min_d = 1;  
  894.      int node_a, node_b;  
  895.      for (int i=0;i<D.rows;i++){  
  896.          for (int j=i+1;j<D.cols;j++){  
  897.              if (D.at<float>(i,j)<min_d && belongs[i]!=belongs[j]){  
  898.                  min_d = D.at<float>(i,j);  
  899.                  node_a = i;  
  900.                  node_b = j;  
  901.              }  
  902.          }  
  903.      }  
  904.      if (min_d>0.5){  
  905.          int max_idx =0;  
  906.          bool visited;  
  907.          for (int j=0;j<c;j++){  
  908.              visited = false;  
  909.              for(int i=0;i<2*c-1;i++){  
  910.                  if (belongs[j]==i){  
  911.                      indexes[j]=max_idx;  
  912.                      visited = true;  
  913.                  }  
  914.              }  
  915.              if (visited)  
  916.                max_idx++;  
  917.          }  
  918.          return max_idx;  
  919.      }  
  920.   
  921.  //4. Merge clusters and assign level  
  922.      L[m]=min_d;  
  923.      nodes[it][0] = belongs[node_a];  
  924.      nodes[it][1] = belongs[node_b];  
  925.      for (int k=0;k<c;k++){  
  926.          if (belongs[k]==belongs[node_a] || belongs[k]==belongs[node_b])  
  927.            belongs[k]=m;  
  928.      }  
  929.      m++;  
  930.  }  
  931.  return 1;  
  932.   
  933. }  
  934.   
  935. //对检测器检测到的目标bounding box进行聚类  
  936. //聚类(Cluster)分析是由若干模式(Pattern)组成的,通常,模式是一个度量(Measurement)的向量,或者是多维空间中的  
  937. //一个点。聚类分析以相似性为基础,在一个聚类中的模式之间比不在同一聚类中的模式之间具有更多的相似性。  
  938. void TLD::clusterConf(const vector<BoundingBox>& dbb,const vector<float>& dconf,vector<BoundingBox>& cbb,vector<float>& cconf){  
  939.   int numbb =dbb.size();  
  940.   vector<int> T;  
  941.   float space_thr = 0.5;  
  942.   int c=1;    //记录 聚类的类个数  
  943.   switch (numbb){  //检测到的含有目标的bounding box个数  
  944.   case 1:  
  945.     cbb=vector<BoundingBox>(1,dbb[0]);  //如果只检测到一个,那么这个就是检测器检测到的目标  
  946.     cconf=vector<float>(1,dconf[0]);  
  947.     return;  
  948.     break;  
  949.   case 2:  
  950.     T =vector<int>(2,0);  
  951.     //此函数计算两个bounding box 的重叠度  
  952.     if (1 - bbOverlap(dbb[0],dbb[1]) > space_thr){  //如果只检测到两个box,但他们的重叠度小于0.5  
  953.       T[1]=1;  
  954.       c=2;  //重叠度小于0.5的box,属于不同的类  
  955.     }  
  956.     break;  
  957.   default:  //检测到的box数目大于2个,则筛选出重叠度大于0.5的  
  958.     T = vector<int>(numbb, 0);  
  959.     //stable_partition()重新排列元素,使得满足指定条件的元素排在不满足条件的元素前面。它维持着两组元素的顺序关系。  
  960.     //STL partition就是把一个区间中的元素按照某个条件分成两类。返回第二类子集的起点  
  961.     //bbcomp()函数判断两个box的重叠度小于0.5,返回false,否则返回true (分界点是重叠度:0.5)  
  962.     //partition() 将dbb划分为两个子集,将满足两个box的重叠度小于0.5的元素移动到序列的前面,为一个子集,重叠度大于0.5的,  
  963.     //放在序列后面,为第二个子集,但两个子集的大小不知道,返回第二类子集的起点  
  964.     c = partition(dbb, T, (*bbcomp));   //重叠度小于0.5的box,属于不同的类,所以c是不同的类别个数  
  965.     //c = clusterBB(dbb,T);  
  966.     break;  
  967.   }  
  968.     
  969.   cconf=vector<float>(c);   
  970.   cbb=vector<BoundingBox>(c);  
  971.   printf("Cluster indexes: ");  
  972.   BoundingBox bx;  
  973.   for (int i=0;i<c;i++){   //类别个数  
  974.       float cnf=0;  
  975.       int N=0,mx=0,my=0,mw=0,mh=0;  
  976.       for (int j=0;j<T.size();j++){  //检测到的bounding box个数  
  977.           if (T[j]==i){   //将聚类为同一个类别的box的坐标和大小进行累加  
  978.               printf("%d ",i);  
  979.               cnf=cnf+dconf[j];  
  980.               mx=mx+dbb[j].x;  
  981.               my=my+dbb[j].y;  
  982.               mw=mw+dbb[j].width;  
  983.               mh=mh+dbb[j].height;  
  984.               N++;  
  985.           }  
  986.       }  
  987.       if (N>0){   //然后求该类的box的坐标和大小的平均值,将平均值作为该类的box的代表  
  988.           cconf[i]=cnf/N;  
  989.           bx.x=cvRound(mx/N);  
  990.           bx.y=cvRound(my/N);  
  991.           bx.width=cvRound(mw/N);  
  992.           bx.height=cvRound(mh/N);  
  993.           cbb[i]=bx;  //返回的是聚类,每一个类都有一个代表的bounding box  
  994.       }  
  995.   }  
  996.   printf("\n");  
  997. }  

下面是自己在看论文和这些大牛的分析过程中,对代码进行了一些理解,但是由于自己接触图像处理和机器视觉没多久,另外由于自己编程能力比较弱,所以分析过程可能会有不少的错误,希望各位不吝指正。而且,因为编程很多地方不懂,所以注释得非常乱,还海涵。

 

FerNNClassifier.h

[cpp]  view plain copy
  1. /* 
  2.  * FerNNClassifier.h 
  3.  * 
  4.  *  Created on: Jun 14, 2011 
  5.  *      Author: alantrrs 
  6.  */  
  7.   
  8. #include <opencv2/opencv.hpp>  
  9. #include <stdio.h>  
  10. class FerNNClassifier{  
  11. private:  
  12.   //下面这些参数通过程序开始运行时读入parameters.yml文件进行初始化  
  13.   float thr_fern;  
  14.   int structSize;  
  15.   int nstructs;  
  16.   float valid;  
  17.   float ncc_thesame;  
  18.   float thr_nn;  
  19.   int acum;  
  20. public:  
  21.   //Parameters  
  22.   float thr_nn_valid;  
  23.   
  24.   void read(const cv::FileNode& file);  
  25.   void prepare(const std::vector<cv::Size>& scales);  
  26.   void getFeatures(const cv::Mat& image,const int& scale_idx,std::vector<int>& fern);  
  27.   void update(const std::vector<int>& fern, int C, int N);  
  28.   float measure_forest(std::vector<int> fern);  
  29.   void trainF(const std::vector<std::pair<std::vector<int>,int> >& ferns,int resample);  
  30.   void trainNN(const std::vector<cv::Mat>& nn_examples);  
  31.   void NNConf(const cv::Mat& example,std::vector<int>& isin,float& rsconf,float& csconf);  
  32.   void evaluateTh(const std::vector<std::pair<std::vector<int>,int> >& nXT,const std::vector<cv::Mat>& nExT);  
  33.   void show();  
  34.   //Ferns Members  
  35.   int getNumStructs(){return nstructs;}  
  36.   float getFernTh(){return thr_fern;}  
  37.   float getNNTh(){return thr_nn;}  
  38.     
  39.   struct Feature   //特征结构体  
  40.       {  
  41.           uchar x1, y1, x2, y2;  
  42.           Feature() : x1(0), y1(0), x2(0), y2(0) {}  
  43.           Feature(int _x1, int _y1, int _x2, int _y2)  
  44.           : x1((uchar)_x1), y1((uchar)_y1), x2((uchar)_x2), y2((uchar)_y2)  
  45.           {}  
  46.           bool operator ()(const cv::Mat& patch) const  
  47.           {   
  48.             //二维单通道元素可以用Mat::at(i, j)访问,i是行序号,j是列序号  
  49.             //返回的patch图像片在(y1,x1)和(y2, x2)点的像素比较值,返回0或者1  
  50.             return patch.at<uchar>(y1,x1) > patch.at<uchar>(y2, x2);   
  51.           }  
  52.       };  
  53.   //Ferns(蕨类植物:有根、茎、叶之分,不具花)features 特征组?  
  54.   std::vector<std::vector<Feature> > features; //Ferns features (one std::vector for each scale)  
  55.   std::vector< std::vector<int> > nCounter; //negative counter  
  56.   std::vector< std::vector<int> > pCounter; //positive counter  
  57.   std::vector< std::vector<float> > posteriors; //Ferns posteriors  
  58.   float thrN; //Negative threshold  
  59.   float thrP;  //Positive thershold  
  60.     
  61.   //NN Members  
  62.   std::vector<cv::Mat> pEx; //NN positive examples  
  63.   std::vector<cv::Mat> nEx; //NN negative examples  
  64. };  


 

 

FerNNClassifier.cpp

[cpp]  view plain copy
  1. /* 
  2.  * FerNNClassifier.cpp 
  3.  * 
  4.  *  Created on: Jun 14, 2011 
  5.  *      Author: alantrrs 
  6.  */  
  7.   
  8. #include <FerNNClassifier.h>  
  9.   
  10. using namespace cv;  
  11. using namespace std;  
  12.   
  13. void FerNNClassifier::read(const FileNode& file){  
  14.   ///Classifier Parameters  
  15.   //下面这些参数通过程序开始运行时读入parameters.yml文件进行初始化  
  16.   valid = (float)file["valid"];  
  17.   ncc_thesame = (float)file["ncc_thesame"];  
  18.   nstructs = (int)file["num_trees"];   //树木(由一个特征组构建,每组特征代表图像块的不同视图表示)的个数  
  19.   structSize = (int)file["num_features"];  //每棵树的特征个数,也即每棵树的节点个数;树上每一个特征都作为一个决策节点  
  20.   thr_fern = (float)file["thr_fern"];  
  21.   thr_nn = (float)file["thr_nn"];  
  22.   thr_nn_valid = (float)file["thr_nn_valid"];  
  23. }  
  24.   
  25. void FerNNClassifier::prepare(const vector<Size>& scales){  
  26.   acum = 0;  
  27.   //Initialize test locations for features  
  28.   int totalFeatures = nstructs * structSize;  
  29.   //二维向量  包含全部尺度(scales)的扫描窗口,每个尺度包含totalFeatures个特征  
  30.   features = vector<vector<Feature> >(scales.size(), vector<Feature> (totalFeatures));  
  31.    
  32.   //opencv中自带的一个随机数发生器的类RNG  
  33.   RNG& rng = theRNG();  
  34.     
  35.   float x1f,x2f,y1f,y2f;  
  36.   int x1, x2, y1, y2;  
  37.   //集合分类器基于n个基本分类器,每个分类器都是基于一个pixel comparisons(像素比较集)的;  
  38.   //pixel comparisons的产生方法:先用一个归一化的patch去离散化像素空间,产生所有可能的垂直和水平的pixel comparisons  
  39.   //然后我们把这些pixel comparisons随机分配给n个分类器,每个分类器得到完全不同的pixel comparisons(特征集合),  
  40.   //这样,所有分类器的特征组统一起来就可以覆盖整个patch了  
  41.     
  42.   //用随机数去填充每一个尺度扫描窗口的特征  
  43.   for (int i=0;i<totalFeatures;i++){  
  44.       x1f = (float)rng;  
  45.       y1f = (float)rng;  
  46.       x2f = (float)rng;  
  47.       y2f = (float)rng;  
  48.       for (int s=0; s<scales.size(); s++){  
  49.           x1 = x1f * scales[s].width;  
  50.           y1 = y1f * scales[s].height;  
  51.           x2 = x2f * scales[s].width;  
  52.           y2 = y2f * scales[s].height;  
  53.           //第s种尺度的第i个特征  两个随机分配的像素点坐标  
  54.           features[s][i] = Feature(x1, y1, x2, y2);  
  55.       }  
  56.   }  
  57.   //Thresholds  
  58.   thrN = 0.5 * nstructs;  
  59.   
  60.   //Initialize Posteriors  初始化后验概率  
  61.   //后验概率指每一个分类器对传入的图像片进行像素对比,每一个像素对比得到0或者1,所有的特征13个comparison对比,  
  62.   //连成一个13位的二进制代码x,然后索引到一个记录了后验概率的数组P(y|x),y为0或者1(二分类),也就是出现x的  
  63.   //基础上,该图像片为y的概率是多少对n个基本分类器的后验概率做平均,大于0.5则判定其含有目标  
  64.   for (int i = 0; i<nstructs; i++) {  
  65.   //每一个每类器维护一个后验概率的分布,这个分布有2^d个条目(entries),这里d是像素比较pixel comparisons  
  66.   //的个数,这里是structSize,即13个comparison,所以会产生2^13即8,192个可能的code,每一个code对应一个后验概率  
  67.   //后验概率P(y|x)= #p/(#p+#n) ,#p和#n分别是正和负图像片的数目,也就是下面的pCounter和nCounter  
  68.   //初始化时,每个后验概率都得初始化为0;运行时候以下面方式更新:已知类别标签的样本(训练样本)通过n个分类器  
  69.   //进行分类,如果分类结果错误,那么响应的#p和#n就会更新,这样P(y|x)也相应更新了  
  70.       posteriors.push_back(vector<float>(pow(2.0,structSize), 0));  
  71.       pCounter.push_back(vector<int>(pow(2.0,structSize), 0));  
  72.       nCounter.push_back(vector<int>(pow(2.0,structSize), 0));  
  73.   }  
  74. }  
  75.   
  76. //该函数得到输入的image的用于树的节点,也就是特征组的特征(13位的二进制代码)  
  77. void FerNNClassifier::getFeatures(const cv::Mat& image, const int& scale_idx, vector<int>& fern){  
  78.   int leaf;  //叶子  树的最终节点  
  79.   //每一个每类器维护一个后验概率的分布,这个分布有2^d个条目(entries),这里d是像素比较pixel comparisons  
  80.   //的个数,这里是structSize,即13个comparison,所以会产生2^13即8,192个可能的code,每一个code对应一个后验概率  
  81.   for (int t=0; t<nstructs; t++){  //nstructs 表示树的个数 10  
  82.       leaf=0;  
  83.       for (int f=0; f<structSize; f++){  //表示每棵树特征的个数 13  
  84.         //struct Feature 特征结构体有一个运算符重载 bool operator ()(const cv::Mat& patch) const  
  85.         //返回的patch图像片在(y1,x1)和(y2, x2)点的像素比较值,返回0或者1  
  86.         //然后leaf就记录了这13位的二进制代码,作为特征  
  87.           leaf = (leaf << 1) + features[scale_idx][t*nstructs+f](image);  
  88.       }  
  89.       fern[t] = leaf;   
  90.   }  
  91. }  
  92.   
  93. float FerNNClassifier::measure_forest(vector<int> fern) {  
  94.   float votes = 0;  
  95.   for (int i = 0; i < nstructs; i++) {  
  96.      // 后验概率posteriors[i][idx] = ((float)(pCounter[i][idx]))/(pCounter[i][idx] + nCounter[i][idx]);  
  97.       votes += posteriors[i][fern[i]];   //每棵树的每个特征值对应的后验概率累加值 作投票值??  
  98.   }  
  99.   return votes;  
  100. }  
  101.   
  102. //更新正负样本数,同时更新后验概率  
  103. void FerNNClassifier::update(const vector<int>& fern, int C, int N) {  
  104.   int idx;  
  105.   for (int i = 0; i < nstructs; i++) {  
  106.       idx = fern[i];  
  107.       (C==1) ? pCounter[i][idx] += N : nCounter[i][idx] += N;  
  108.       if (pCounter[i][idx]==0) {  
  109.           posteriors[i][idx] = 0;  
  110.       } else {  
  111.           posteriors[i][idx] = ((float)(pCounter[i][idx]))/(pCounter[i][idx] + nCounter[i][idx]);  
  112.       }  
  113.   }  
  114. }  
  115.   
  116. //训练集合分类器(n个基本分类器集合)  
  117. void FerNNClassifier::trainF(const vector<std::pair<vector<int>,int> >& ferns,int resample){  
  118.   // Conf = function(2,X,Y,Margin,Bootstrap,Idx)  
  119.   //                 0 1 2 3      4         5  
  120.   //  double *X     = mxGetPr(prhs[1]); -> ferns[i].first  
  121.   //  int numX      = mxGetN(prhs[1]);  -> ferns.size()  
  122.   //  double *Y     = mxGetPr(prhs[2]); ->ferns[i].second  
  123.   //  double thrP   = *mxGetPr(prhs[3]) * nTREES; ->threshold*nstructs  
  124.   //  int bootstrap = (int) *mxGetPr(prhs[4]); ->resample  
  125.     
  126.   //thr_fern: 0.6 thrP定义为Positive thershold  
  127.   thrP = thr_fern * nstructs;                                    // int step = numX / 10;  
  128.   //for (int j = 0; j < resample; j++) {                      // for (int j = 0; j < bootstrap; j++) {  
  129.       for (int i = 0; i < ferns.size(); i++){               //   for (int i = 0; i < step; i++) {  
  130.                                                             //     for (int k = 0; k < 10; k++) {  
  131.                                                             //       int I = k*step + i;//box index  
  132.                                                             //       double *x = X+nTREES*I; //tree index  
  133.           if(ferns[i].second==1){    //为1表示正样本        //       if (Y[I] == 1) {  
  134.               //measure_forest函数返回所有树的所有特征值对应的后验概率累加值  
  135.               //该累加值如果小于正样本阈值,也就是是输入的是正样本,却被分类成负样本了  
  136.               //出现分类错误,所以就把该样本添加到正样本库,同时更新后验概率  
  137.               if(measure_forest(ferns[i].first) <= thrP)      //         if (measure_forest(x) <= thrP)  
  138.               更新正样本数,同时更新后验概率  
  139.                 update(ferns[i].first, 1, 1);                 //             update(x,1,1);  
  140.           }else{                                            //        }else{  
  141.               if (measure_forest(ferns[i].first) >= thrN)   //         if (measure_forest(x) >= thrN)  
  142.                 update(ferns[i].first, 0, 1);                 //             update(x,0,1);  
  143.           }  
  144.       }  
  145.   //}  
  146. }  
  147.   
  148. //训练最近邻分类器  
  149. void FerNNClassifier::trainNN(const vector<cv::Mat>& nn_examples){  
  150.   float conf, dummy;  
  151.   vector<int> y(nn_examples.size(),0); //vector<T> v3(n, i); v3包含n个值为i的元素。y数组元素初始化为0  
  152.   y[0]=1;  //上面说到调用trainNN这个函数传入的nn_data样本集,只有一个pEx,在nn_data[0]  
  153.   vector<int> isin;  
  154.   for (int i=0; i<nn_examples.size(); i++){                          //  For each example  
  155.       //计算输入图像片与在线模型之间的相关相似度conf  
  156.       NNConf(nn_examples[i], isin, conf, dummy);                      //  Measure Relative similarity  
  157.       //thr_nn: 0.65 阈值  
  158.       //标签是正样本,如果相关相似度小于0.65 ,则认为其不含有前景目标,也就是分类错误了;这时候就把它加到正样本库  
  159.       if (y[i]==1 && conf <= thr_nn){                                //    if y(i) == 1 && conf1 <= tld.model.thr_nn % 0.65  
  160.           if (isin[1]<0){                                          //      if isnan(isin(2))  
  161.               pEx = vector<Mat>(1,nn_examples[i]);                 //        tld.pex = x(:,i);  
  162.               continue;                                            //        continue;  
  163.           }                                                        //      end  
  164.           //pEx.insert(pEx.begin()+isin[1],nn_examples[i]);        //      tld.pex = [tld.pex(:,1:isin(2)) x(:,i) tld.pex(:,isin(2)+1:end)]; % add to model  
  165.           pEx.push_back(nn_examples[i]);  
  166.       }                                                            //    end  
  167.       if(y[i]==0 && conf>0.5)                                      //  if y(i) == 0 && conf1 > 0.5  
  168.         nEx.push_back(nn_examples[i]);                             //    tld.nex = [tld.nex x(:,i)];  
  169.   
  170.   }                                                                 //  end  
  171.   acum++;  
  172.   printf("%d. Trained NN examples: %d positive %d negative\n",acum,(int)pEx.size(),(int)nEx.size());  
  173. }                                                                  //  end  
  174.   
  175.   /*Inputs: 
  176.    * -NN Patch 
  177.    * Outputs: 
  178.    * -Relative Similarity (rsconf)相关相似度, Conservative Similarity (csconf)保守相似度, 
  179.    * In pos. set|Id pos set|In neg. set (isin) 
  180.    */  
  181. void FerNNClassifier::NNConf(const Mat& example, vector<int>& isin,float& rsconf,float& csconf){  
  182.   isin=vector<int>(3,-1);  //vector<T> v3(n, i); v3包含n个值为i的元素。 三个元素都是-1  
  183.   if (pEx.empty()){ //if isempty(tld.pex) % IF positive examples in the model are not defined THEN everything is negative  
  184.       rsconf = 0; //    conf1 = zeros(1,size(x,2));  
  185.       csconf=0;  
  186.       return;  
  187.   }  
  188.   if (nEx.empty()){ //if isempty(tld.nex) % IF negative examples in the model are not defined THEN everything is positive  
  189.       rsconf = 1;   //    conf1 = ones(1,size(x,2));  
  190.       csconf=1;  
  191.       return;  
  192.   }  
  193.   Mat ncc(1,1,CV_32F);  
  194.   float nccP, csmaxP, maxP=0;  
  195.   bool anyP=false;  
  196.   int maxPidx, validatedPart = ceil(pEx.size()*valid);  //ceil返回大于或者等于指定表达式的最小整数  
  197.   float nccN, maxN=0;  
  198.   bool anyN=false;  
  199.   //比较图像片p到在线模型M的距离(相似度),计算正样本最近邻相似度,也就是将输入的图像片与  
  200.   //在线模型中所有的图像片进行匹配,找出最相似的那个图像片,也就是相似度的最大值  
  201.   for (int i=0;i<pEx.size();i++){  
  202.       matchTemplate(pEx[i], example, ncc, CV_TM_CCORR_NORMED);      // measure NCC to positive examples  
  203.       nccP=(((float*)ncc.data)[0]+1)*0.5;  //计算匹配相似度  
  204.       if (nccP>ncc_thesame)  //ncc_thesame: 0.95  
  205.         anyP=true;  
  206.       if(nccP > maxP){  
  207.           maxP=nccP;    //记录最大的相似度以及对应的图像片index索引值  
  208.           maxPidx = i;  
  209.           if(i<validatedPart)  
  210.             csmaxP=maxP;  
  211.       }  
  212.   }  
  213.   //计算负样本最近邻相似度  
  214.   for (int i=0;i<nEx.size();i++){  
  215.       matchTemplate(nEx[i],example,ncc,CV_TM_CCORR_NORMED);     //measure NCC to negative examples  
  216.       nccN=(((float*)ncc.data)[0]+1)*0.5;  
  217.       if (nccN>ncc_thesame)  
  218.         anyN=true;  
  219.       if(nccN > maxN)  
  220.         maxN=nccN;  
  221.   }  
  222.   //set isin  
  223.   //if he query patch is highly correlated with any positive patch in the model then it is considered to be one of them  
  224.   if (anyP) isin[0]=1;    
  225.   isin[1]=maxPidx;      //get the index of the maximall correlated positive patch  
  226.   //if  the query patch is highly correlated with any negative patch in the model then it is considered to be one of them  
  227.   if (anyN) isin[2]=1;   
  228.     
  229.   //Measure Relative Similarity  
  230.   //相关相似度 = 正样本最近邻相似度 / (正样本最近邻相似度 + 负样本最近邻相似度)  
  231.   float dN=1-maxN;  
  232.   float dP=1-maxP;  
  233.   rsconf = (float)dN/(dN+dP);  
  234.     
  235.   //Measure Conservative Similarity  
  236.   dP = 1 - csmaxP;  
  237.   csconf =(float)dN / (dN + dP);  
  238. }  
  239.   
  240. void FerNNClassifier::evaluateTh(const vector<pair<vector<int>,int> >& nXT, const vector<cv::Mat>& nExT){  
  241.   float fconf;  
  242.   for (int i=0;i<nXT.size();i++){  
  243.   //所有基本分类器的后验概率的平均值如果大于thr_fern,则认为含有前景目标  
  244.   //measure_forest返回的是所有后验概率的累加和,nstructs 为树的个数,也就是基本分类器的数目 ??  
  245.     fconf = (float) measure_forest(nXT[i].first)/nstructs;  
  246.     if (fconf>thr_fern)  //thr_fern: 0.6 thrP定义为Positive thershold  
  247.       thr_fern = fconf;  //取这个平均值作为 该集合分类器的 新的阈值,这就是训练??  
  248.   }  
  249.     
  250.   vector <int> isin;  
  251.   float conf, dummy;  
  252.   for (int i=0; i<nExT.size(); i++){  
  253.       NNConf(nExT[i], isin, conf, dummy);  
  254.       if (conf > thr_nn)  
  255.         thr_nn = conf; //取这个最大相关相似度作为 该最近邻分类器的 新的阈值,这就是训练??  
  256.   }  
  257.     
  258.   if (thr_nn > thr_nn_valid)  //thr_nn_valid: 0.7  
  259.     thr_nn_valid = thr_nn;  
  260. }  
  261.   
  262. //把正样本库(在线模型)包含的所有正样本显示在窗口上  
  263. void FerNNClassifier::show(){  
  264.   Mat examples((int)pEx.size()*pEx[0].rows, pEx[0].cols, CV_8U);  
  265.   double minval;  
  266.   Mat ex(pEx[0].rows, pEx[0].cols, pEx[0].type());  
  267.   for (int i=0;i<pEx.size();i++){  
  268.     //minMaxLoc寻找矩阵(一维数组当作向量,用Mat定义)中最小值和最大值的位置.   
  269.     minMaxLoc(pEx[i], &minval); //寻找pEx[i]的最小值  
  270.     pEx[i].copyTo(ex);  
  271.     ex = ex - minval;  //把像素亮度最小的像素重设为0,其他像素按此重设  
  272.     //Mat Mat::rowRange(int startrow, int endrow) const 为指定的行span创建一个新的矩阵头。  
  273.     //Mat Mat::rowRange(const Range& r) const   //Range 结构包含着起始和终止的索引值。  
  274.     Mat tmp = examples.rowRange(Range(i*pEx[i].rows, (i+1)*pEx[i].rows));  
  275.     ex.convertTo(tmp, CV_8U);  
  276.   }  
  277.   imshow("Examples", examples);  
  278. }  

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。
1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。
1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。1、资源项目源码均已通过严格测试验证,保证能够正常运行; 2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通; 3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业,更为适合; 4、下载使用后,可先查看README.md文件(如有),本项目仅用作交流学习参考,请切勿用于商业用途。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值