目的
- 使用泛型（模板）实现元素类型为 int、double 等的矩阵运算。
问题
- 重载运算符时，两个操作数矩阵的元素类型、行数、列数都可能不同，如何用泛型表示这些类型？这里采用“类模板 + 友元函数模板”的方式实现。
代码
#include <iostream>
#include <vector>
using namespace std;
/// Fixed-size dense matrix whose element type and shape are template
/// parameters, so e.g. TMatrix<int, 2, 3> and TMatrix<double, 3, 2> are
/// distinct types.
///
/// @tparam T element type (e.g. int, double).
/// @tparam r number of rows; must be non-negative.
/// @tparam c number of columns; must be non-negative.
template<typename T, int r, int c>
class TMatrix{
    // Catch bad shapes at compile time instead of as a length error when the
    // row/column vectors are sized with a negative count.
    static_assert(r >= 0 && c >= 0, "TMatrix dimensions must be non-negative");
public:
    /// Creates an r x c matrix with every element value-initialized
    /// (zero for arithmetic T).
    TMatrix() : row(r), colum(c), elem(r, std::vector<T>(c)) {}

    /// Creates an r x c matrix from a row-major array.
    /// @param e array holding at least r * c elements; element (i, j) is
    ///          read from e[i * c + j]. It is only read, so const data is accepted.
    TMatrix(const T e[]) : row(r), colum(c), elem(r, std::vector<T>(c)) {
        for(int i = 0; i < row; i++){
            for(int j = 0; j < colum; j++){
                elem[i][j] = e[i * colum + j];
            }
        }
    }

    /// @return number of rows (always r).
    int rows() const { return row; }

    /// @return number of columns (always c).
    int cols() const { return colum; }

    /// Read-only element access.
    /// @throws std::out_of_range if (i, j) lies outside the matrix.
    const T& at(int i, int j) const { return elem.at(i).at(j); }

    // Every specialization of these function templates is a friend of every
    // TMatrix specialization, so they can read and write `elem` of operands
    // and results whose element type or shape differs from this one.
    template<typename T1, int r1, int c1, typename T2, int r2, int c2>
    friend TMatrix<T1, r1, c1> operator+(const TMatrix<T1, r1, c1>& t1, const TMatrix<T2, r2, c2>& t2);
    template<typename T1, int r1, int c1, typename T2, int r2, int c2>
    friend TMatrix<T1, r1, c2> operator*(const TMatrix<T1, r1, c1>& m1, const TMatrix<T2, r2, c2>& m2);
    template<typename T1, int r1, int c1>
    friend void printMatrix(const TMatrix<T1, r1, c1>& m);
protected:
    // Runtime copies of r and c ("colum" = column count).
    int row, colum;
    // Row-major storage: r rows, each holding c elements.
    std::vector<std::vector<T>> elem;
};
/// Element-wise sum of two matrices.
///
/// The operands may have different element types. Each sum is converted to
/// the left operand's element type T1, because the declared result type is
/// TMatrix<T1, r1, c1>; it must match the friend declaration inside TMatrix.
///
/// @param t1 left operand, r1 x c1.
/// @param t2 right operand, r2 x c2.
/// @return t1 + t2 when both are non-empty and have the same shape;
///         otherwise an r1 x c1 matrix of value-initialized (zero) elements.
///         A shape mismatch is not reported as an error.
template<typename T1, int r1, int c1, typename T2, int r2, int c2>
TMatrix<T1, r1, c1> operator+(const TMatrix<T1, r1, c1>& t1, const TMatrix<T2, r2, c2>& t2){
    // Build the result directly as the declared return type. A matrix of the
    // promoted element type (e.g. double for int + double) would not convert
    // to TMatrix<T1, ...>, because TMatrix has no converting constructor.
    TMatrix<T1, r1, c1> m3;
    if(r1 == 0 || c1 == 0 || r1 != r2 || c1 != c2){
        return m3;
    }
    for(int i = 0; i < r1; i++){
        for(int j = 0; j < c1; j++){
            m3.elem[i][j] = static_cast<T1>(t1.elem[i][j] + t2.elem[i][j]);
        }
    }
    return m3;
}
/// Matrix product of an r1 x c1 matrix and an r2 x c2 matrix.
///
/// Each dot product is accumulated in the promoted product type
/// (decltype(T1 * T2)) to avoid truncating intermediate sums. It is then
/// converted to T1, because the declared result type is TMatrix<T1, r1, c2>;
/// it must match the friend declaration inside TMatrix.
///
/// @param m1 left operand, r1 x c1.
/// @param m2 right operand, r2 x c2.
/// @return m1 * m2 when c1 == r2 and no dimension is zero; otherwise an
///         r1 x c2 matrix of value-initialized (zero) elements. A shape
///         mismatch is not reported as an error.
template<typename T1, int r1, int c1, typename T2, int r2, int c2>
TMatrix<T1, r1, c2> operator*(const TMatrix<T1, r1, c1>& m1, const TMatrix<T2, r2, c2>& m2){
    // Unevaluated: only names the promoted type of an element product.
    using Acc = decltype(T1() * T2());
    TMatrix<T1, r1, c2> m3;
    // The inner dimensions must agree; the shapes are compile-time constants.
    if(r1 == 0 || c1 == 0 || c2 == 0 || c1 != r2){
        return m3;
    }
    for(int i = 0; i < r1; i++){
        for(int j = 0; j < c2; j++){
            Acc sum = Acc();  // value-initialized: zero for arithmetic types
            for(int k = 0; k < c1; k++){
                sum += m1.elem[i][k] * m2.elem[k][j];
            }
            m3.elem[i][j] = static_cast<T1>(sum);
        }
    }
    return m3;
}
template<typename T1, int r1, int c1>
void printMatrix(const TMatrix<T1, r1, c1>& m){
for(int i = 0; i < r1; i++){
for(int j = 0; j < c1; j++){
cout << m.elem[i][j] << " ";
}
cout << endl;
}
cout << endl;
}
int main()
{
double a[] = {1.5, 2, 3, 4.2, 5, 6};
TMatrix<double, 3, 2> t1(a);
TMatrix<double, 3, 2> t2(a);
TMatrix<double, 3, 2> t3(a);
TMatrix<double, 2, 3> t4(a);
printMatrix(t1 + t2);
cout << endl;
printMatrix(t3 * t4);
cout << "Hello" << endl;
return 0;
}