matlab c++混合编程
- 目的:尝试c++与matlab混合编程,用c++编写关键代码,matlab编译后可以调用,主要用在瓶颈以提高速度。
- 功能:实现两个矩阵间,成对样本点的欧氏距离计算。
- 注:在matlab中已经有pdist(),pdist2()函数可以快速的实现相应的功能,这里只做练习。
/************************************************************************/
/* mex函数:实现c++与matlab混合编程,程序实现成对样本点的欧式距离的平方的计算 */
/* 注:为了快速计算,求根操作没有实现,需要在该函数调用后,在MATLAB中进行求根运算*/
/************************************************************************/
#include "mex.h"
#define Matrix1(row,col) Matrix1[col*n+row]
#define Matrix2(row,col) Matrix2[col*n+row]
#define Matrix_out(row,col) Matrix_out[col*m1+row]
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
//参数检查
if (nrhs != 2)
{
mexErrMsgTxt("输入参数只能是两个矩阵");
}
if (nlhs > 1)
{
mexErrMsgTxt("输出参数只能为一个矩阵");
}
#define INT_1 prhs[0]
#define INT_2 prhs[1]
#define OUT plhs[0]
double *Matrix1, *Matrix2,*Matrix_out;
int n, m1, m2,ri,ci,rni;
m1 = mxGetN(INT_1);
m2 = mxGetN(INT_2);
n = mxGetM(INT_1);
if (mxGetM(INT_2) != n)
{
mexErrMsgTxt("两个输入矩阵的行数必须相同");
}
Matrix1 = mxGetPr(INT_1);
Matrix2 = mxGetPr(INT_2);
//OUT = mxCreateDoubleMatrix(m1, m2, mxREAL);
OUT = mxCreateUninitNumericMatrix(m1, m2, mxDOUBLE_CLASS, mxREAL);//创建未初始化的矩阵,可以加速
Matrix_out = mxGetPr(OUT);
//距离计算
for (ri=0;ri<m1;ri++)
{
for (ci=0;ci<m2;ci++)
{
double temp = 0;
for (rni = 0; rni < n; rni++)
{
//这里不要用pow(),很慢
double sub = (Matrix1(rni, ri) - Matrix2(rni, ci));
temp += sub * sub;
}
//这里没有求根
Matrix_out(ri, ci) = temp;
}
}
}
测试:
a=rand(1024,4000);
b=rand(5000,1000);
tic
aa=pdist2(a',a','euclidean');%自带的对距离计算函数,核心调用pdist2mex
toc
tic
bb=pdistmex(a,'euc',[],[],[]);bb=squareform(bb);
%自带的对距离计算函数,核心调用pdist2mex
toc
% tic
% c=DistMatrix(a',b');
% toc
% tic
% cc=mypdist(a,b);
% toc
% tic
% ccc=disPoint(a',b');
% toc
tic
aa=pdist2(a',b','euclidean');%自带的对距离计算函数,核心调用pdist2mex
toc
% aa=pdist(a');%aa是一个向量,包含下三角矩阵的值
% aa=squareform(aa);
tic
bb=pdist2mex(a,b,'euc',[],[],[]);%自带的mex函数速度更快
toc
tic
cc=mysecondpdist(a,b);cc=sqrt(cc);%自己的,速度慢,求根运算不在其中
toc
效果:速度不如matlab自带的pdist2()函数,大概是其1/5,另外,pdist2()函数的核心是调用经编译的pdist2mex.mexw64文件,路径:D:\Program Files\MATLAB\R2017a\toolbox\stats\stats\private。速度最快。
参考资料:
https://stackoverflow.com/questions/19253405/pairwise-distance-calculation-in-c 优化代码提高速度
http://blog.sciencenet.cn/blog-531885-589056.html matlab中pdist()和pdist2()函数的使用方法
http://www.getreuer.info/ Pascal Getreuer的两本书 Writing Fast MATLAB Code和Writing MATLAB C/MEX Code
http://blog.sciencenet.cn/blog-43777-320103.html mexFunction函数相关