因为个人项目原因,我曾将参考OpenMax源码GitHub - abhijitbendale/OSDN: Code and data for the research paper “Towards Open Set Deep Networks” A Bendale, T Boult, CVPR 2016将其转换到MATLAB使用。
OpenMax 使用极值理论实现对开放集的筛选,在计算得分时需要用libMR 包,用于计算 weibull 概率分布。因此,训练过程中会需要fit_high() 函数用于提取右端极大值数据进行拟合,保存每一个类别拟合的MetaRecognition 对象。测试时调用预测类别 MetaRecognition 对象的CDF() 函数计算得分,获得预测样本的得分。
为了能够在 MATLAB 中实现调用 libMR ,我做出了如下操作:
-
定义了一个句柄类 MetaRecognitionHandle,里面封装一些与 MetaRecognition
相关的操作,主要是封装了需要使用到的函数,使得这些操作可以在 MATLAB 中以面向对象的方式调用。 -
针对要用到的函数定义对应的 mexfunction文件,实现对 fit_high() 和CDF() 函数的调用。
-
MATLAB 命令行一次运行
mex XXX.cpp
。
为了在 MATLAB 中保存 MetaRecognition 对象,我在运行 MATLAB 代码时为每一个 MetaRecognition 对象开辟一个相应的内存空间,用于调用相应类别拟合成的对象所对应的函数。
classdef MetaRecognitionHandle < handle
properties (Access = private)
ObjectHandle; % 存储C++对象指针
end
methods
function obj = MetaRecognitionHandle()
obj.ObjectHandle = mexCreateMetaRecognition(); % 创建MetaRecognition对象的MEX函数
end
function delete(obj)
mexDeleteMetaRecognition(obj.ObjectHandle); % 销毁MetaRecognition对象的MEX函数
end
function c=fitHigh(obj, tailToFit, tailSize)
c=mexFitHigh(obj.ObjectHandle, tailToFit, tailSize); % 调用fit_high的MEX函数
end
function score = wScore(obj, channelDistance)
score = mexWScore(obj.ObjectHandle, channelDistance); % 调用w_score的MEX函数
end
end
end
fit_high() 函数对应 mexFitHigh.cpp
#include "mex.h"
#include "MetaRecognition.h"
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
if (nrhs != 3) {
mexErrMsgIdAndTxt("MetaRecognition:nlhs", "3 is needed.");
}
MetaRecognition* mr = reinterpret_cast<MetaRecognition*>(*((uint64_t*)mxGetData(prhs[0])));
double* tailToFit = mxGetPr(prhs[1]);
size_t tailSize = static_cast<size_t>(mxGetScalar(prhs[2]));
plhs[0] = mxCreateDoubleScalar(mr->FitHigh(tailToFit, tailSize, tailSize));
}
CDF() 函数对应 mexWScore.cpp
#include "mex.h"
#include "MetaRecognition.h"
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
if (nrhs != 2) {
mexErrMsgIdAndTxt("MetaRecognition:wScore:nrhs", "2 is need");
}
MetaRecognition* mr = reinterpret_cast<MetaRecognition*>(*((uint64_t*)mxGetData(prhs[0])));
double channelDistance = mxGetScalar(prhs[1]);
double score = mr->CDF(channelDistance);
plhs[0] = mxCreateDoubleScalar(score);
}
mexCreateMetaRecognition.cpp
#include "mex.h"
#include "MetaRecognition.h"
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
MetaRecognition* mr = new MetaRecognition();
plhs[0] = mxCreateNumericMatrix(1, 1, mxUINT64_CLASS, mxREAL);
*((uint64_t*)mxGetData(plhs[0])) = reinterpret_cast<uint64_t>(mr);
}
mexDeleteMetaRecognition.cpp
#include "mex.h"
#include "MetaRecognition.h"
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
uint64_t handle = *((uint64_t*)mxGetData(prhs[0]));
MetaRecognition* mr = reinterpret_cast<MetaRecognition*>(handle);
delete mr;
}
调用方法
tailToFit = sort(distance_scores, 'descend');
tailToFit = tailToFit(1:tailsize); % 取最大的tailsize个值
mr = MetaRecognitionHandle();
mr.fitHigh(tailToFit, tailsize);
wScore = categoryWeibull{3}.wScore(channelDistance);