C++ Overload << and >>
前言
本篇基於Overloading stream insertion (<>) operators in C++,並加上TensorRT中的例子做為參考。
Overload << and >>
在C++中,<<
被稱為"stream insertion operator",用於輸出;>>
被稱為"stream extraction operator",用於輸入。
要overload這兩個運算子,我們有兩種選擇:
- 使它們成為"運算子左邊的運算元"的成員函數
- 使它們成為global function
如果採用第一種選擇,首先要知道的是:
<<
左邊的運算元,如cout
,是std::ostream
型別的物件<<
左邊的運算元,如cin
,是std::istream
型別的物件
這表示我們得重寫ostream
或istream
類別(它們是"運算子左邊的運算元"),為它們新增overload過的成員函數,這很明顯不是一個好選擇。
Overload as global function
因此通常的做法是將它們定義為global function。這個函數共接受兩個參數:一為istream
或ostream
物件本身,二為它們要輸入或輸出的物件。如下例:
#include <iostream>
using namespace std;
class Complex
{
private:
int real, imag;
public:
Complex(int r = 0, int i =0)
{ real = r; imag = i; }
friend ostream & operator << (ostream &out, const Complex &c);
friend istream & operator >> (istream &in, Complex &c);
};
ostream & operator << (ostream &out, const Complex &c)
{
out << c.real;
out << "+i" << c.imag << endl;
return out;
}
istream & operator >> (istream &in, Complex &c)
{
cout << "Enter Real Part ";
in >> c.real;
cout << "Enter Imagenory Part ";
in >> c.imag;
return in;
}
int main()
{
Complex c1;
cin >> c1;
cout << "The complex object is ";
cout << c1;
return 0;
}
觀察<<
函數的簽名:
ostream & operator << (ostream &out, const Complex &c)
我們可以發現幾點:
- 其第一個參數為
ostream
物件本身,以參考方式傳入 - 第二個參數為欲輸出的物件,同樣以參考方式傳入,前面加入
const
,防止它被意外修改 - 函數會回傳
ostream
物件,所以我們才能連續輸出,如:cout << a << b;
。並且注意到它是以"參考"方式回傳,這樣我們才不必每次創建一個新的ostream
物件來用。
觀察>>
函數的簽名:
istream & operator >> (istream &in, Complex &c)
寫法與<<
基本一致,唯一要注意的是現在函數的第二個參數沒有const
修飾字了。
另外注意到在Complex
這個類別中,將<<
及>>
兩個overload的函數宣告為friend function。這使得此二函數能有存取Complex
的私有成員變數real
及imag
的權限。
TensorRT中的例子
在TensorRT/parsers/common/parserUtils.h
中overload了<<
這個運算子:
inline std::ostream& operator<<(std::ostream& o, const nvinfer1::Dims& dims)
{
o << "[";
for (int i = 0; i < dims.nbDims; i++)
o << (i ? "," : "") << dims.d[i];
o << "]";
return o;
}
TensorRT/include/NvInferRuntimeCommon.h
中Dims
類別的定義:
class Dims
{
public:
static const int MAX_DIMS = 8; //!< The maximum number of dimensions supported for a tensor.
int nbDims; //!< The number of dimensions.
int d[MAX_DIMS]; //!< The extent of each dimension.
TRT_DEPRECATED DimensionType type[MAX_DIMS]; //!< The type of each dimension.
};
因為nbDims
及d
都是Dims
的public成員變數,所以<<
可以自由地存取它們。因此我們不必在Dims
裡將<<
宣告為friend function。