写代码时候需要用到kron这个函数,matlab中有,但搜了一圈竟然OpenCV没有,而且复现的很一般,因此,我尽可能用最标准的方法,也就是兼容各种数据格式,实现了这个函数,复现之后,在我的研究中测试,无误。
下面就放上这个函数的声明,m1,m2
为输入的二维矩阵, mkron
为输出的对应矩阵。
void kron(cv::InputArray m1, cv::InputArray m2, cv::OutputArray mkron);
在matlab中的实例
m1 = [1,2,3;2,3,4];
m2 = [2,3,4;3,4,5];
mkron = kron(m1,m2);
%% 输出:
mkron =
2 3 4 4 6 8 6 9 12
3 4 5 6 8 10 9 12 15
4 6 8 6 9 12 8 12 16
6 8 10 9 12 15 12 16 20
在C++中用法:
int m1[] = { 1,2,3,2,3,4 };
int m2[] = { 2,3,4,3,4,5 };
cv::Mat cvm1(2, 3, CV_32SC1, m1), cvm2(2, 3, CV_32SC1, m2), mkron;
kron(cvm1, cvm2, mkron);
std::cout << mkron << std::endl;
// 输出
// [2, 3, 4, 4, 6, 8, 6, 9, 12;
// 3, 4, 5, 6, 8, 10, 9, 12, 15;
// 4, 6, 8, 6, 9, 12, 8, 12, 16;
// 6, 8, 10, 9, 12, 15, 12, 16, 20]
下面放上对应的源代码,注意,m1和m2数据类型要一致,且仅支持CV_64F, CV_32F, CV_32S, CV_16S这几种数据类型的计算。
void kron(cv::InputArray m1, cv::InputArray m2, cv::OutputArray mkron)
{
int ma = m1.rows(), na = m1.cols(), mb = m2.rows(), nb = m2.cols();
CV_Assert(m1.type() == m2.type() &&
(m1.type() == CV_32F || m1.type() == CV_64F || m1.type() == CV_32S || m1.type() == CV_16S));
cv::Mat K(ma*mb, na*nb, m1.type());
cv::Mat tm1 = m1.getMat();
cv::Mat tm2 = m2.getMat();
cv::Rect roi;
switch (m1.type())
{
case CV_64F:
{
double *_data = (double*)tm1.data;
for (int i = 0; i < ma; i++)
{
for (int j = 0; j < na; j++)
{
roi.y = i * mb, roi.x = j * nb;
roi.width = nb, roi.height = mb;
tm2.convertTo(K(roi), m1.type(), _data[i*na + j]);
}
}
K.copyTo(mkron);
break;
}
case CV_32F:
{
float *_data = (float*)tm1.data;
for (int i = 0; i < ma; i++)
{
for (int j = 0; j < na; j++)
{
roi.y = i * mb, roi.x = j * nb;
roi.width = nb, roi.height = mb;
tm2.convertTo(K(roi), m1.type(), _data[i*na + j]);
}
}
K.copyTo(mkron);
break;
}
case CV_32S:
{
int *_data = (int*)tm1.data;
for (int i = 0; i < ma; i++)
{
for (int j = 0; j < na; j++)
{
roi.y = i * mb, roi.x = j * nb;
roi.width = nb, roi.height = mb;
tm2.convertTo(K(roi), m1.type(), _data[i*na + j]);
}
}
K.copyTo(mkron);
break;
}
case CV_16S:
{
short *_data = (short*)tm1.data;
for (int i = 0; i < ma; i++)
{
for (int j = 0; j < na; j++)
{
roi.y = i * mb, roi.x = j * nb;
roi.width = nb, roi.height = mb;
tm2.convertTo(K(roi), m1.type(), _data[i*na + j]);
}
}
K.copyTo(mkron);
break;
}
default:
break;
}
}
代码的后面为了方便管理数据类型,利用switch对每个数据类型进行了判断,因此看起来冗余了许多,反正使用的时候直接粘贴就ok。