矩阵就不用再解释了,写成泛型主要是为了几个方便:
1、方便在栈上分配空间。由于维度在编译期已知,所以可以做到在栈上分配空间。当然如果这个对象是new出来的,自然是在堆上分配,这里说的是在栈上分配这个对象时,矩阵元素所占用的空间也在栈上分配。
2、方便在编译期检查非法的矩阵运算。C++模板的强大推导能力可以在编译期推导出结果矩阵的维度。
3、泛型类在方法内联上具有优势。
这个矩阵类为了能够直接从数组赋值,使用了一个ArrayPorxy类(可参考《Imperfect C++》)。
代码如下:
template
<
class
T,
int
D1,
int
D2
=
1
>
class ArrayProxy
{
T * data;
public :
ArrayProxy(T ( & value)[D1][D2])
: data( & value[ 0 ][ 0 ])
{
}
ArrayProxy(T ( & value)[D1 * D2])
: data(value)
{
}
T * getData() const
{
return data;
}
};
class ArrayProxy
{
T * data;
public :
ArrayProxy(T ( & value)[D1][D2])
: data( & value[ 0 ][ 0 ])
{
}
ArrayProxy(T ( & value)[D1 * D2])
: data(value)
{
}
T * getData() const
{
return data;
}
};
这个只是简单的实现。
因为我基本上不使用这个矩阵类,所以只完成几个简单功能:
1、从数组赋值:
int a[][3] = {{1,2,3}, {4,5,6}};
Matrix<int, 2, 3> m1(a);
或
int a[] = {1,2,3, 4,5,6};
Matrix<int, 2, 3> m1(a);
Matrix<int, 3, 2> m2(a);
Matrix<int, 6, 1> m3(a);
Matrix<int, 1, 6> m4(a);
2、矩阵乘法:
Matrix<int, 2, 3> m1;
Matrix<int, 2, 4> m2;
// m1 * m2 <== 编译错误,维度不匹配
Matrix<int, 3, 5> m3;
Matrix<int, 2, 5> m4 = m1 * m3; // <== 合法
// m3 * m1; // <== 编译错误,维度不匹配
源码如下:
template
<
class
T,
int
R,
int
C
>
class Matrix
{
T matrix[R][C];
public :
// Big three
Matrix( void )
{
memset(matrix, 0 , sizeof (matrix));
}
Matrix( const Matrix & rhs)
{
memcpy(matrix, rhs.matrix, sizeof (matrix));
}
Matrix & operator = ( const Matrix & rhs)
{
memcpy(matrix, rhs.matrix, sizeof (matrix));
return * this ;
}
public :
Matrix( const ArrayProxy < T,R,C >& arr)
{
memcpy(matrix, arr.getData(), sizeof (matrix));
}
~ Matrix( void )
{
}
public :
T get ( int r, int c) const
{
assert(c < C && c >= 0 && r < R && r >= 0 );
return matrix[r][c];
}
void set ( int r, int c, T v)
{
assert(c < C && c >= 0 && r < R && r >= 0 );
matrix[r][c] = v;
}
int getCols () const
{
return C;
}
int getRows () const
{
return R;
}
bool operator == ( const Matrix & rhs) const
{
return memcmp(matrix, rhs.matrix, sizeof (matrix)) == 0 ;
}
bool operator != ( const Matrix & rhs) const
{
return ! ( * this == rhs);
}
};
template < class T, int R, int C, int C1 >
Matrix < T,R,C1 > operator * ( const Matrix < T,R,C >& lhs, const Matrix < T,C,C1 >& rhs)
{
Matrix < T,R,C1 > result;
for ( int r = 0 ; r < R; ++ r)
{
for ( int c = 0 ; c < C1; ++ c)
{
int value = 0 ;
for ( int i = 0 ; i < C; ++ i)
{
value += lhs. get (r,i) * rhs. get (i,c);
}
result. set (r,c,value);
}
}
return result;
}
class Matrix
{
T matrix[R][C];
public :
// Big three
![dot.gif](/Images/dot.gif)
Matrix( void )
{
memset(matrix, 0 , sizeof (matrix));
}
Matrix( const Matrix & rhs)
{
memcpy(matrix, rhs.matrix, sizeof (matrix));
}
Matrix & operator = ( const Matrix & rhs)
{
memcpy(matrix, rhs.matrix, sizeof (matrix));
return * this ;
}
public :
Matrix( const ArrayProxy < T,R,C >& arr)
{
memcpy(matrix, arr.getData(), sizeof (matrix));
}
~ Matrix( void )
{
}
public :
T get ( int r, int c) const
{
assert(c < C && c >= 0 && r < R && r >= 0 );
return matrix[r][c];
}
void set ( int r, int c, T v)
{
assert(c < C && c >= 0 && r < R && r >= 0 );
matrix[r][c] = v;
}
int getCols () const
{
return C;
}
int getRows () const
{
return R;
}
bool operator == ( const Matrix & rhs) const
{
return memcmp(matrix, rhs.matrix, sizeof (matrix)) == 0 ;
}
bool operator != ( const Matrix & rhs) const
{
return ! ( * this == rhs);
}
};
template < class T, int R, int C, int C1 >
Matrix < T,R,C1 > operator * ( const Matrix < T,R,C >& lhs, const Matrix < T,C,C1 >& rhs)
{
Matrix < T,R,C1 > result;
for ( int r = 0 ; r < R; ++ r)
{
for ( int c = 0 ; c < C1; ++ c)
{
int value = 0 ;
for ( int i = 0 ; i < C; ++ i)
{
value += lhs. get (r,i) * rhs. get (i,c);
}
result. set (r,c,value);
}
}
return result;
}
测试代码:
int
main()
{
{
// 测试初始化
Matrix < int , 3 , 4 > m1;
Matrix < int , 3 , 4 > m2(m1);
Matrix < int , 3 , 4 > m3 = m1;
Matrix < int , 3 , 4 > m4;
m4 = m1;
for ( int i = 0 ; i < 3 ; i ++ )
for ( int j = 0 ; j < 4 ; j ++ )
{
assert (m1. get (i, j) == 0 );
assert (m2. get (i, j) == 0 );
assert (m3. get (i, j) == 0 );
assert (m4. get (i, j) == 0 );
}
int a[] = { 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 };
Matrix < int , 3 , 4 > m5(a);
int b[ 3 ][ 4 ] = { { 1 , 2 , 3 , 4 },
{ 5 , 6 , 7 , 8 },
{ 9 , 10 , 11 , 12 }};
Matrix < int , 3 , 4 > m6(b);
Matrix < int , 3 , 4 > m7(m5);
Matrix < int , 3 , 4 > m8 = m5;
Matrix < int , 3 , 4 > m9;
m9 = m5;
for ( int i = 0 ; i < 3 ; i ++ )
for ( int j = 0 ; j < 4 ; j ++ )
{
assert (m5. get (i, j) == i * 4 + j + 1 );
assert (m6. get (i, j) == i * 4 + j + 1 );
assert (m7. get (i, j) == i * 4 + j + 1 );
assert (m8. get (i, j) == i * 4 + j + 1 );
assert (m9. get (i, j) == i * 4 + j + 1 );
}
// 维数不匹配,编译错误
// Matrix<int, 4, 5> m10 = m9;
int c[][ 2 ] = {{ 1 , 2 }, { 2 , 3 }};
// 数组大小不匹配,编译错误
// Matrix<int, 3, 4> m10(c);
int d[] = { 1 , 2 };
// 数组大小不匹配,编译错误
// Matrix<int, 3, 4> m11(d);
// 乘法维数不合适,无法相乘
// m1 * m2;
Matrix < int , 4 , 3 > m12;
// 匹配,可以相乘
Matrix < int , 3 , 3 > m13 = m1 * m12;
Matrix < int , 8 , 3 > m14;
// 无法相乘
// Matrix<int, 3, 3> m15 = m1 * m14;
// 可以相乘
Matrix < int , 8 , 4 > m15 = m14 * m1;
}
{
// 检查点乘
int a[ 2 ][ 5 ] = {{ 1 , 2 , 3 , 4 , 5 }, { 6 , 7 , 8 , 9 , 10 }};
Matrix < int , 2 , 5 > m1(a);
int b[ 5 ][ 3 ] = {{ 1 , 2 , 3 }, { 4 , 5 , 6 }, { 7 , 8 , 9 }, { 10 , 11 , 12 }, { 13 , 14 , 15 }};
Matrix < int , 5 , 3 > m2(b);
int c[ 2 ][ 3 ] = {{ 135 , 150 , 165 }, { 310 , 350 , 390 }};
Matrix < int , 2 , 3 > m3(c);
Matrix < int , 2 , 3 > m4 = m1 * m2;
assert(m4 == m3);
cout << m4. get ( 0 , 0 ) << endl;
}
return 0 ;
}
{
{
// 测试初始化
Matrix < int , 3 , 4 > m1;
Matrix < int , 3 , 4 > m2(m1);
Matrix < int , 3 , 4 > m3 = m1;
Matrix < int , 3 , 4 > m4;
m4 = m1;
for ( int i = 0 ; i < 3 ; i ++ )
for ( int j = 0 ; j < 4 ; j ++ )
{
assert (m1. get (i, j) == 0 );
assert (m2. get (i, j) == 0 );
assert (m3. get (i, j) == 0 );
assert (m4. get (i, j) == 0 );
}
int a[] = { 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 };
Matrix < int , 3 , 4 > m5(a);
int b[ 3 ][ 4 ] = { { 1 , 2 , 3 , 4 },
{ 5 , 6 , 7 , 8 },
{ 9 , 10 , 11 , 12 }};
Matrix < int , 3 , 4 > m6(b);
Matrix < int , 3 , 4 > m7(m5);
Matrix < int , 3 , 4 > m8 = m5;
Matrix < int , 3 , 4 > m9;
m9 = m5;
for ( int i = 0 ; i < 3 ; i ++ )
for ( int j = 0 ; j < 4 ; j ++ )
{
assert (m5. get (i, j) == i * 4 + j + 1 );
assert (m6. get (i, j) == i * 4 + j + 1 );
assert (m7. get (i, j) == i * 4 + j + 1 );
assert (m8. get (i, j) == i * 4 + j + 1 );
assert (m9. get (i, j) == i * 4 + j + 1 );
}
// 维数不匹配,编译错误
// Matrix<int, 4, 5> m10 = m9;
int c[][ 2 ] = {{ 1 , 2 }, { 2 , 3 }};
// 数组大小不匹配,编译错误
// Matrix<int, 3, 4> m10(c);
int d[] = { 1 , 2 };
// 数组大小不匹配,编译错误
// Matrix<int, 3, 4> m11(d);
// 乘法维数不合适,无法相乘
// m1 * m2;
Matrix < int , 4 , 3 > m12;
// 匹配,可以相乘
Matrix < int , 3 , 3 > m13 = m1 * m12;
Matrix < int , 8 , 3 > m14;
// 无法相乘
// Matrix<int, 3, 3> m15 = m1 * m14;
// 可以相乘
Matrix < int , 8 , 4 > m15 = m14 * m1;
}
{
// 检查点乘
int a[ 2 ][ 5 ] = {{ 1 , 2 , 3 , 4 , 5 }, { 6 , 7 , 8 , 9 , 10 }};
Matrix < int , 2 , 5 > m1(a);
int b[ 5 ][ 3 ] = {{ 1 , 2 , 3 }, { 4 , 5 , 6 }, { 7 , 8 , 9 }, { 10 , 11 , 12 }, { 13 , 14 , 15 }};
Matrix < int , 5 , 3 > m2(b);
int c[ 2 ][ 3 ] = {{ 135 , 150 , 165 }, { 310 , 350 , 390 }};
Matrix < int , 2 , 3 > m3(c);
Matrix < int , 2 , 3 > m4 = m1 * m2;
assert(m4 == m3);
cout << m4. get ( 0 , 0 ) << endl;
}
return 0 ;
}
补充:
1、加法、减法只需要2个矩阵维度相同即可。
template
<
class
T,
class
R,
class
C
>
Matrix < T,R,C > operator + ( const Matrix < T,R,C >& lhs, const Matrix < T,R,C >& rhs)
{
//
}
Matrix < T,R,C > operator + ( const Matrix < T,R,C >& lhs, const Matrix < T,R,C >& rhs)
{
//
![dot.gif](/Images/dot.gif)
}
2、由于1x1的矩阵可以看成一个标量,矩阵与标量运算结果维数与原矩阵相同,可以重载来实现。
template
<
class
T,
class
R,
class
C
>
Matrix < T,R,C > operator * ( const Matrix < T,R,C >& lhs, const Matrix<T,1,1> & rhs)
{
//
}
Matrix < T,R,C > operator * ( const Matrix < T,R,C >& lhs, const Matrix<T,1,1> & rhs)
{
//
![dot.gif](/Images/dot.gif)
}
3、由于类型泛化,可能某些合理的运算无法进行,比如float型矩阵,与一个int型标量运算等。这些最好是借助类型萃取等手段,推导出运算以后的类型。(c++0x中包含自动获取运算结果类型的关键字typeof,等几年就可以用了:)。GCC编译器中已有实现,不过似乎有BUG)。
4、其它。泛型实现可能会有一些考虑不周的地方,强类型有强类型的好处,不过必须要有完整的泛型算法支撑,否则难以使用。也可以把泛型矩阵类从一个普通矩阵类派生,这样更容易写出通用算法,不过在实现上可能要借助于运行期多态,对于矩阵类来说并不合适。
5、其它。。前面说C++的模板相当强大,D语言模板到目前为止似乎已经完全实现了C++模板的功能,还增加了一些比如字符串值参模板等特性,比C++模板功能更多。在代码编写上,由于可以编写静态判断语句(编译期)以及静态断言,编写模板比C++更容易。有时间可以试试用它写个矩阵类,纯粹是兴趣,这些东西真的很难用到,现成的库也挺多。
6、其它。。。c++0x要提供“template typedef”,也就是可以这样定义:
template <int R, int C> typedef Matrix<int, R, C> MatrixInt; // 定义类型,维度不定
template <class T> typedef Matrix<T, 4, 4> Matrix4x4; // 定义维度,类型不定
由此可以出定义行向量、列向量、标量等,当然实际使用起来可能没那么舒服了。