表达式模板是一种C++模板元编程(template metaprogram)技术。典型情况下,表达式模板自身代表一种操作,模板参数代表该操作的操作数。模板表达式可将子表达式的计算推迟,这样 有利于优化(特别是减少临时变量的使用)。表达式模板也可以作为参数传递给一个函数。
例子:我们实现一个用来求表达式 x = 1.2*x + x*y 的模板表达式,其中x、y为数组
//exprarray.h
#include <stddef.h>
#include <cassert>
#include "sarray.h"
template<typename T>
class A_Scale
{
public:
A_Scale(T const& t):value(t){}
T operator[](size_t) const
{
return value;
}
size_t size() const
{
return 0;
}
private:
T const& value;
};
template<typename T>
class A_Traits
{
public:
typedef T const& exprRef;
};
template<typename T>
class A_Traits<A_Scale<T> >
{
public:
typedef A_Scale<T> exprRef;
};
template<typename T,typename L1,typename R2>
class A_Add
{
private:
typename A_Traits<L1>::exprRef op1;
typename A_Traits<R2>::exprRef op2;
public:
A_Add(L1 const& a,R2 const& b):op1(a),op2(b)
{
}
T operator[](size_t indx) const
{
return op1[indx] + op2[indx];
}
size_t size() const
{
assert(op1.size()==0 || op2.size()==0 || op1.size() == op2.size());
return op1.size() != 0 ? op1.size() : op2.size();
}
};
template<typename T,typename L1,typename R2>
class A_Mul
{
private:
typename A_Traits<L1>::exprRef op1;
typename A_Traits<R2>::exprRef op2;
public:
A_Mul(L1 const& a,R2 const& b):op1(a),op2(b)
{
}
T operator[](size_t indx) const
{
return op1[indx] * op2[indx];
}
size_t size() const
{
assert(op1.size()==0 || op2.size()==0 || op1.size() == op2.size());
return op1.size() != 0 ? op1.size():op2.size();
}
};
template<typename T,typename Rep = SArray<T> >
class Array
{
public:
explicit Array(size_t N):expr_Rep(N){}
Array(Rep const& rep):expr_Rep(rep){}
Array& operator=(Array<T> const& orig)
{
assert(size() == orig.size());
for (size_t indx=0;indx < orig.size();indx++)
{
expr_Rep[indx] = orig[indx];
}
return *this;
}
template<typename T2,typename Rep2>
Array& operator=(Array<T2,Rep2> const& orig)
{
assert(size() == orig.size());
for (size_t indx=0;indx<orig.size();indx++)
{
expr_Rep[indx] = orig[indx];
}
return *this;
}
size_t size() const
{
return expr_Rep.size();
}
T operator[](size_t indx) const
{
assert(indx < size());
return expr_Rep[indx];
}
T& operator[](size_t indx)
{
assert(indx < size());
return expr_Rep[indx];
}
Rep const& rep() const
{
return expr_Rep;
}
Rep& rep()
{
return expr_Rep;
}
private:
Rep expr_Rep;
};
template<typename T,typename L1,typename R2>
Array<T,A_Add<T,L1,R2> >
operator+(Array<T,L1> const& a,Array<T,R2> const& b)
{
return Array<T,A_Add<T,L1,R2> >(A_Add<T,L1,R2>(a.rep(),b.rep()));
}
template<typename T,typename L1,typename R2>
Array<T,A_Mul<T,L1,R2> >
operator*(Array<T,L1> const& a,Array<T,R2> const& b)
{
return Array<T,A_Mul<T,L1,R2> >(A_Mul<T,L1,R2>(a.rep(),b.rep()));
}
template<typename T,typename R2>
Array<T,A_Mul<T,A_Scale<T>,R2> >
operator*(T const& a,Array<T,R2> const& b)
{
return Array<T,A_Mul<T,A_Scale<T>,R2> >(A_Mul<T,A_Scale<T>,R2>(A_Scale<T>(a),b.rep()));
}
测试代码(求解表达式1.2*x+x*y):
//test.cpp
#include "exprarray.h"
#include <iostream>
using namespace std;
template <typename T>
void print (T const& c)
{
for (int i=0; i<8; ++i) {
std::cout << c[i] << ' ';
}
std::cout << "..." << std::endl;
}
int main()
{
Array<double> x(1000), y(1000);
for (int i=0; i<1000; ++i) {
x[i] = i;
y[i] = x[i]+x[i];
}
std::cout << "x: ";
print(x);
std::cout << "y: ";
print(y);
x = 1.2 * x;
std::cout << "x = 1.2 * x: ";
print(x);
x = 1.2*x + x*y;
std::cout << "1.2*x + x*y: ";
print(x);
x = y;
std::cout << "after x = y: ";
print(x);
return 0;
}
下面我们来分析一下模板表达式的解析过程:
我们以表达式 x = 1.2*x + x*y为例
当编译器解析表达式:x = 1.2*x + x*y 的时候,编译器首先会应用最左边的*运算符,它是一个Scale-Array运算符。于是重载解析规则将会选择operator*的Scale-Array形式:
template<typename T,typename R2>
Array<T,A_Mul<T,A_Scale<T>,R2> >
operator*(T const& a,Array<T,R2> const& b)
{
return Array<T,A_Mul<T,A_Scale<T>,R2> >(A_Mul<T,A_Scale<T>,R2>(A_Scale<T>(a),b.rep()));
}
其中操作数的类型是double和Array<double,SArray<double> >,因此实际的结果类型是:
Array<double,A_Mul<double,A_Scale<double>,SArray<double> > >
接下来,编译器会对第二个乘法进行求值:x*y是一个array-array操作,这一次,我们将会选择operator*的Array-Array重载操作:
template<typename T,typename L1,typename R2>
Array<T,A_Mul<T,L1,R2> >
operator*(Array<T,L1> const& a,Array<T,R2> const& b)
{
return Array<T,A_Mul<T,L1,R2> >(A_Mul<T,L1,R2>(a.rep(),b.rep()));
}
其中两个操作数类型都是Array<double,SArray<double> >,因此结果类型为:
Array<double,A_Mul<double,SArray<double>,SArray<double> > >
这一次,A_Mul所封装的连个参数对象都引用了一个SArray<double>表示:即一个表示x对象,一个表示y对象。
现在开始对+运算符进行求值。这次还是Array-Array操作,因此调用Array-Array版本的operator+:
template<typename T,typename L1,typename R2>
Array<T,A_Add<T,L1,R2> >
operator+(SArray<T,L1> const& a,SArray<T,R2> const& b)
{
return Array<T,L1,R2>(A_Add<T,L1,R2>(a.rep(),b.rep()));
}
其中用double来替换T,则R1为:
A_Mul<double,A_Scale<double>,SArray<double> >
R2为:
A_Mul<double,SArray<double>,SArray<double> >
因此赋值表达式 x = 1.2*x + x*y的右边经过编译器解析后的最终类型为:
Array<double,
A_Add<double,
A_Mul<double,A_Scale<double>,SArray<double> >
A_Mul<double,SArray<double>,SArray<double> > > >
这个类型将与Array模板的赋值运算符模板进行匹配:
//针对不同类型数组的赋值运算符
template<typename T2,typename Rep2>
Array& operator=(Array<T2,Rep2> const& orig)
{
assert(size() == orig.size());
for (size_t indx=0;indx<orig.size();indx++)
{
expr_Rep[indx] = orig[indx];
}
return *this;
}
此时,赋值运算符将会运用右边Array的下标运算符来计算目标数组的每一个元素,而Array的实际类型为:
Array<double,
A_Add<double,
A_Mul<double,A_Scale<double>,SArray<double> >
A_Mul<double,SArray<double>,SArray<double> > > >
我们记为:ArrayTgt
此时,ArrayTgt[indx]将会匹配模板类A_Add中的重载操作符operator[],即:
T operator[](size_t indx) const
{
return op1[indx] + op2[indx];
}
匹配之后就变成:
A_Mul<double,A_Scale<double>,SArray<double> >[indx]
+
A_Mul<double,SArray<double>,SArray<double> >[indx];
而A_Mul[indx]又会匹配模板类A_Mul中的重载操作符operator[],即:
T operator[](size_t indx) const
{
return op1[indx] * op2[indx];
}
匹配之后就变成:
A_Scale<double>[indx] * SArray<double>[indx]
+
SArray<double>[indx] * SArray<double>[indx]
而A_Scale[indx]又会匹配模板类A_Scale中的重载操作符operator[],即:
T operator[](size_t) const
{
return value;
}
这样最终的结果就表达式就变成:
value[indx] * SArray<double>[indx]
+
SArray<double>[indx] * SArray<double>[indx]
至此,整个模板表达式的解析工作已经完成,只需进行计算即可。在整个计算过程中,没有产生任何的中间变量,所以程序的效率得以大幅的提高。
程序注意事项:
1.在上述代码中,如果将模板类Array的代码:
Array& operator=(Array<T2,Rep2> const& orig)
{
assert(size() == orig.size());
for (size_t indx=0;indx<orig.size();indx++)
{
expr_Rep[indx] = orig[indx];
}
return *this;
}
中的参数改为Array<T2,Rep2> & orig,即变成:
Array& operator=(Array<T2,Rep2>& orig)
{
assert(size() == orig.size());
for (size_t indx=0;indx<orig.size();indx++)
{
expr_Rep[indx] = orig[indx];
}
return *this;
}
将会导致编译出错,原因是:
在test.cpp文件中,我们使用了表达式:x = 1.2 * x ,这个表达式的右边将会被编译器解析为如下形式的表达式:
Array<double,A_Mul<double,A_Scale<double>,SArray<doube> > >
这样在进行重载操作符operator[]的匹配时,将会变成如下形式:
SArray[indx] = A_Scale[indx] * SArray[indx]
到了这一步,问题就出现了,因为A_Scale[indx]会匹配模板类Array中的重载操作符operator[],但是我们发现在模板类Array代码中,有两个重载的operator[],即:
T operator[](size_t indx) const
{
assert(indx < size());
return expr_Rep[indx];
}
T& operator[](size_t indx)
{
assert(indx < size());
return expr_Rep[indx];
}
如果我们没在重载操作符operator=的参数中写入const的话,这里会优先调用无const的operator[]重载函数,但是A_Scale[indx]是个常数,在本例中也就是一个double类型,这样最后在调用operator[]返回的时候就出现了类型不匹配的现象,因为无const的operator[]返回的类型是double&,所以会报错。当然,我们可以将test.cpp程序中的表达式1.2*x去掉,我们会发现,这个时候无const的operator=就会编译通过。
2.在上述代码中,模板类Array的构造函数代码为:
explicit Array(size_t N):expr_Rep(N){}
这表明定义Array必须通过显式转型,不能通过隐式转型。下述代码会导致编译出错:
Array a = 5;
我们只能使用
Array a(5);
进行显式初始化。
下面用一个例子来区别显式转型和隐式转型的细微区别:
X x;
Y y(x); //显式转型
Y y = x;//隐式转型
其中前者通过使用从X到Y类型的显式转型,新建一个类型为Y的对象。后者使用了一个从类型X到Y类型的隐式转型,新建了一个类型Y的对象。
3.在上述代码中,模板类Array的两个重载操作符operator[]代码:
T operator[](size_t indx) const
{
assert(indx < size());
return expr_Rep[indx];
}
T& operator[](size_t indx)
{
assert(indx < size());
return expr_Rep[indx];
}
注意在一个重载操作符函数后面的const一定不能少,否则会导致编译错误。因为没有const的话,函数
T operator[](size_t indx)
{
assert(indx < size());
return expr_Rep[indx];
}
和
T& operator[](size_t indx)
{
assert(indx < size());
return expr_Rep[indx];
}
会被认为是一个函数,因为他们静静是返回类型不同而已。函数
int test(){}
和
int test() const{}
会被编译器理解为两个不同的函数。
最后将SArray的代码附上:
#ifndef SARRAY_H
#define SARRAY_H
#include <stddef.h>
#include <cassert>
template<typename T>
class SArray
{
public:
explicit SArray(size_t N):ptr(new T[N]),_size(N)
{
init();
}
SArray(SArray<T> const& orig):ptr(new T[orig.size()]),_size(orig.size())
{
copy(orig);
}
~SArray()
{
delete[] ptr;
}
size_t size() const
{
return _size;
}
T operator[](size_t indx) const
{
return ptr[indx];
}
T& operator[](size_t indx)
{
return ptr[indx];
}
SArray<T>& operator=(SArray<T> const& orig)
{
if (&orig != this)
{
copy(orig);
}
return *this;
}
protected:
void copy(SArray<T> const& orig)
{
assert(size() == orig.size());
for (size_t indx=0;indx<orig.size();indx++)
{
ptr[indx] = orig[indx];
}
}
void init()
{
for(size_t i=0;i<size();i++)
{
ptr[i] = T();
}
}
private:
T* ptr;
size_t _size;
};
#endif