之前在用ceres做Pose Graph Optimization时,碰到一个痛点。ceres
内部接口使用指针来传递数组,比如四元数是float* quat_ptr
,要访问四元数中的特定元素,我们需要知道它是以wxyz
还是xyzw
排列,这种隐性的约定很容易搞错;直接访问数组易犯越界访问的错误,比如一不小心把quat_ptr[3]
敲成了quat_ptr[4]
。Eigen
提供了Eigen::Map<const Eigen::Quaternion<T>>
来访问该四元数,可以直接使用wrapper.x()
来直接访问数组元素,不需要记住x
元素是数组的第几位。但是有时四元数不按照Eigen
的方式排列时,我们就要手动将其转换成Eigen::Map
所需的形式。因此我们需要一种泛化的数组指针wrapper,它提供按照数组元素名称对各元素进行访问的接口,它还得是zero-overhead的,用它对数组访问的速度必须和通过指针一样快。本着闲着也是闲着不如干点什么的原则,开发了NamedVectorWrapper
。
最终效果
首先来看一下NamedVectorWrapper
的用例:
// 使用NamedVectorWrapper是一个模版类.
// 模版第一个元素是数组的元素类型
// 形参包(parameter pack)Ts是数组每个元素的名字类。
template <typename ValueType, typename... Ts> class NamedVectorWrapper;
// 为每个数组元素定义类。
class X_t {}; constexpr static X_t X;
class Y_t {}; constexpr static Y_t Y;
class Z_t {}; constexpr static Z_t Z;
class W_t {}; constexpr static W_t W;
// 创建一个按照WXYZ排列的四元数Wrapper,按照名字进行访问和修改数组元素。
using QuaternionWXYZWrapper = NamedVectorWrapper<double,W_t,X_t,Y_t,Z_t>;
double quaternion_array[4]={1.0,2.0,3.0,4.0};
QuaternionWXYZWrapper quat_wrapper(quaternion_array);
assert(quat_wrapper[X] == 2.0);
quat_wrapper[X] = 99.0;
assert(quat_wrapper[X] == 99.0);
// 创建一个按照WXYZ排列的常量四元数Wrapper,按照名字进行访问和修改数组元素。
using ConstQuaternionWXYZWrapper = NamedVectorWrapper<const double,W_t,X_t,Y_t,Z_t>;
const double const_quaternion_array[4]={1.0,2.0,3.0,4.0};
ConstQuaternionWXYZWrapper const_quat_wrapper(const_quaternion_array);
assert(const_quat_wrapper[X] == 2.0);
// const_quat_wrapper[X] = 2.0; compile error,nice!
通过上面的例子,我们看到NamedVectorWrapper
用起貌似挺方便的,自夸一下哈哈。但是为每个名字定义类着实蛮烦,容易造成代码里到处都是这种临时的类,最好能够直接使用字符串来访问,幸运的是,C++的自定义字面量提供了这个功能(User Defined Literals),我们自定义了operator "" _t()
,它能创建由字符串定义的类。有了它的助力,我们的代码变成了下面的样子。完美!
using QuaternionWXYZWrapper = NamedVectorWrapper<double,
decltype("w"_t),
decltype("x"_t),
decltype("y"_t),
decltype("z"_t)>;
double quaternion_array[4]={1.0,2.0,3.0,4.0};
QuaternionWXYZWrapper quat_wrapper(quaternion_array);
assert(quat_wrapper["w"_t] == 1.0);
assert(quat_wrapper["x"_t] == 2.0);
assert(quat_wrapper["y"_t] == 3.0);
assert(quat_wrapper["z"_t] == 4.0);
NamedVectorWrapper具体实现
要实现NamedVectorWrapper的通过类名访问数组的元素,必须要有将类名转换到数组的index的功能,这个功能有两个基本需求。第一,查询名字类队列(Type List)中某个元素第一次出现的位置。第二,判定名字类队列(type list)互异,即每个类各不相同,否则我们无法判段用户到底是指哪一个元素,比如我们无法决定X_t
是第零个元素还是第二个元素 <X_t,Y_t,X_t>
,所以我们需要提供判断类队列是否各不相同的函数。这两个功能必须编译期执行,不降低程序的runtime效率,这就要用到我们又爱又恨的meta programming。一个锦上添花的功能是通过字符串常量来构造类。
判断type list互异
通过递归地化模版来判断是否互异。
// --------------检查type list中的类型是否互异------------
// General template
template <typename... T>
struct is_distinct_type_list;
// Template Specialization:当只有一个类时,type list肯定是互异的,所以值为true。
template <typename T>
struct is_distinct_type_list<T>
{
static constexpr bool value = true;
};
// Template Specialization:当有多个类时,先判断第一个类和剩余的类不同,再判断剩余的元素也互异
template <typename T, typename... Ts>
struct is_distinct_type_list<T, Ts...>
{
// 检车T和Ts中每一个类型都不同,且Ts中每个元素互不相同。
static constexpr bool value =
std::conjunction_v<std::negation<std::disjunction<std::is_same<T,Ts>...>>,
is_distinct_type_list<Ts...>>;
};
// ---------------------测试---------------------
void is_distinct_type_list_test()
{
static_assert(is_distinct_type_list<decltype("x"_t)>::value);
static_assert(is_distinct_type_list<decltype("x"_t),decltype("y"_t)>::value);
static_assert(is_distinct_type_list<decltype("x"_t),decltype("y"_t), decltype("z"_t)>::value);
static_assert(std::negation_v<is_distinct_type_list<decltype("x"_t),decltype("x"_t)>>);
static_assert(std::negation_v<is_distinct_type_list<decltype("x"_t),decltype("y"_t), decltype("x"_t)>>);
static_assert(std::negation_v<is_distinct_type_list<decltype("y"_t),decltype("y"_t), decltype("z"_t)>>);
static_assert(std::negation_v<is_distinct_type_list<decltype("x"_t),decltype("y"_t), decltype("y"_t)>>);
static_assert(is_distinct_type_list<int,bool,float,double,std::string,decltype("x"_t)>::value);
static_assert(std::negation_v<is_distinct_type_list<int,bool,float,float,std::string,decltype("x"_t)>>);
static_assert(std::negation_v<is_distinct_type_list<int,bool,float,double,int,std::string,decltype("x"_t)>>);
static_assert(std::negation_v<is_distinct_type_list<int,bool,float,double,float,std::string,decltype("x"_t)>>);
static_assert(std::negation_v<is_distinct_type_list<int,bool,float,double,bool,std::string,decltype("x"_t)>>);
}
在type list中找出某个type的位置
通过递归地具体化模版来找出类在类队列中的位置。
// -------------计算QueryType在TypeTuple中的第一次出现的位置----------------
// General template
template<typename QueryType, typename TypeTuple>
struct tuple_element_index_helper;
// Template Specialization:当TypeTuple为空,默认值为0。
template <typename QueryType>
struct tuple_element_index_helper<QueryType, std::tuple<>>
{
static constexpr size_t value = 0;
};
// Template Specialization:当TypeTuple中第一个元素就为QueryType,值为0。
template <typename QueryType, typename... RestTypes>
struct tuple_element_index_helper<QueryType, std::tuple<QueryType, RestTypes...>>
{
static constexpr size_t value = 0;
};
// Template Specialization:当TypeTuple中第一个元素不是QueryType,
// 值为1+QueryType在RestTuple中的位置。
template <typename QueryType, typename FirstType, typename... RestTypes>
struct tuple_element_index_helper<QueryType, std::tuple<FirstType, RestTypes...>>
{
using RestTuple = std::tuple<RestTypes...>;
static constexpr size_t value = 1 + tuple_element_index_helper<QueryType, RestTuple>::value;
};
// ----------------------帮助函数,隐藏细节--------------------
template <typename QueryType, typename TypeTuple>
constexpr size_t get_type_index()
{
constexpr size_t index = tuple_element_index_helper<QueryType, TypeTuple>::value;
// 如果QueryType不在TypeTuple中,编译错误。
static_assert(index < std::tuple_size_v<TypeTuple>);
return index;
}
void type_index_in_typetuple_test()
{
static_assert(get_type_index<int, std::tuple<int>>() == 0);
// static_assert(get_type_index<int, std::tuple<double>>() == 0); // compile error
static_assert(get_type_index<double, std::tuple<double,int,double>>() == 0);
static_assert(get_type_index<double, std::tuple<int, double,double>>() == 1);
static_assert(get_type_index<double, std::tuple<double,int>>() == 0);
static_assert(get_type_index<int, std::tuple<double,int>>() == 1);
static_assert(get_type_index<bool, std::tuple<bool, double,int>>() == 0);
static_assert(get_type_index<double, std::tuple<bool, double,int>>() == 1);
static_assert(get_type_index<int, std::tuple<bool, double,int>>() == 2);
using StringTypeTestTuple = std::tuple<decltype("x"_t), decltype("xy"_t),decltype("y"_t),decltype("z"_t)>;
static_assert(get_type_index<decltype("x"_t), StringTypeTestTuple>() == 0);
static_assert(get_type_index<decltype("xy"_t), StringTypeTestTuple>() == 1);
static_assert(get_type_index<decltype("y"_t), StringTypeTestTuple>() == 2);
static_assert(get_type_index<decltype("z"_t), StringTypeTestTuple>() == 3);
// get_type_index<decltype("ttt"_t), StringTypeTestTuple>(); // compile error
}
字符串定义的自定义字面量
通过字符串来定义一个字面量,字面量的类型由字符串具体化。
// 由char pack定义的模版类。
template <char... Chars>
class StringDefinedType {};
// 自定义字面量(User Defined Literals)来返回一个由字符串具体化的模版实例。
// 这个语法string literal operator templates 是GNU的扩展, 不是C++标准。
template <typename Char, Char...chars> constexpr StringDefinedType<chars...> operator "" _t()
{
return {};
}
// 编译期测试StringDefinedType
void string_defined_type_test()
{
// 由相同字符串定义的StringDefinedType类型是相同的
static_assert(std::is_same_v<decltype("x"_t), decltype("x"_t)>);
static_assert(std::is_same_v<decltype("ab"_t), decltype("ab"_t)>);
static_assert(std::is_same_v<decltype("ABC"_t), decltype("ABC"_t)>);
// 由不同字符串定义的StringDefinedType类型是相同的
static_assert(std::negation_v<std::is_same<decltype("x"_t), decltype("b"_t)>>);
static_assert(std::negation_v<std::is_same<decltype("ABC"_t), decltype("b"_t)>>);
static_assert(std::negation_v<std::is_same<decltype("a"_t), decltype("ab"_t)>>);
}
运行速度
zero-overhead,和直接使用数组索引一样快!
reference
Modern C++ Design: Generic Programming and Design Patterns Applied
C++ Templates: The Complete Guide
帖子中代码的汇总
#include <iostream>
#include <type_traits>
#include <tuple>
//------------------------------------------------------------------------
// 由char pack定义的模版类。
template <char... Chars>
class StringDefinedType {};
// 自定义字面量(User Defined Literals)来返回一个由字符串具体化的模版实例。
// 这个语法string literal operator templates 是GNU的扩展, 不是C++标准。
template <typename Char, Char...chars> constexpr StringDefinedType<chars...> operator "" _t()
{
return {};
}
// 编译期测试StringDefinedType
void string_defined_type_test()
{
// 由相同字符串定义的StringDefinedType类型是相同的
static_assert(std::is_same_v<decltype("x"_t), decltype("x"_t)>);
static_assert(std::is_same_v<decltype("ab"_t), decltype("ab"_t)>);
static_assert(std::is_same_v<decltype("ABC"_t), decltype("ABC"_t)>);
// 由不同字符串定义的StringDefinedType类型是相同的
static_assert(std::negation_v<std::is_same<decltype("x"_t), decltype("b"_t)>>);
static_assert(std::negation_v<std::is_same<decltype("ABC"_t), decltype("b"_t)>>);
static_assert(std::negation_v<std::is_same<decltype("a"_t), decltype("ab"_t)>>);
}
//------------------------------------------------------------------------
// --------------检查type list中的类型是否互异------------
// General template
template <typename... T>
struct is_distinct_type_list;
// Template Specialization:当只有一个类时,type list肯定是互异的,所以值为true。
template <typename T>
struct is_distinct_type_list<T>
{
static constexpr bool value = true;
};
// Template Specialization:当有多个类时,先判断第一个类和剩余的类不同,再判断剩余的元素也互异
template <typename T, typename... Ts>
struct is_distinct_type_list<T, Ts...>
{
// 检车T和Ts中每一个类型都不同,且Ts中每个元素互不相同。
static constexpr bool value =
std::conjunction_v<std::negation<std::disjunction<std::is_same<T,Ts>...>>,
is_distinct_type_list<Ts...>>;
};
// ---------------------测试---------------------
void is_distinct_type_list_test()
{
static_assert(is_distinct_type_list<decltype("x"_t)>::value);
static_assert(is_distinct_type_list<decltype("x"_t),decltype("y"_t)>::value);
static_assert(is_distinct_type_list<decltype("x"_t),decltype("y"_t), decltype("z"_t)>::value);
static_assert(std::negation_v<is_distinct_type_list<decltype("x"_t),decltype("x"_t)>>);
static_assert(std::negation_v<is_distinct_type_list<decltype("x"_t),decltype("y"_t), decltype("x"_t)>>);
static_assert(std::negation_v<is_distinct_type_list<decltype("y"_t),decltype("y"_t), decltype("z"_t)>>);
static_assert(std::negation_v<is_distinct_type_list<decltype("x"_t),decltype("y"_t), decltype("y"_t)>>);
static_assert(is_distinct_type_list<int,bool,float,double,std::string,decltype("x"_t)>::value);
static_assert(std::negation_v<is_distinct_type_list<int,bool,float,float,std::string,decltype("x"_t)>>);
static_assert(std::negation_v<is_distinct_type_list<int,bool,float,double,int,std::string,decltype("x"_t)>>);
static_assert(std::negation_v<is_distinct_type_list<int,bool,float,double,float,std::string,decltype("x"_t)>>);
static_assert(std::negation_v<is_distinct_type_list<int,bool,float,double,bool,std::string,decltype("x"_t)>>);
}
//------------------------------------------------------------------------
// -------------计算QueryType在TypeTuple中的第一次出现的位置----------------
// General template
template<typename QueryType, typename TypeTuple>
struct tuple_element_index_helper;
// Template Specialization:当TypeTuple为空,默认值为0。
template <typename QueryType>
struct tuple_element_index_helper<QueryType, std::tuple<>>
{
static constexpr size_t value = 0;
};
// Template Specialization:当TypeTuple中第一个元素就为QueryType,值为0。
template <typename QueryType, typename... RestTypes>
struct tuple_element_index_helper<QueryType, std::tuple<QueryType, RestTypes...>>
{
static constexpr size_t value = 0;
};
// Template Specialization:当TypeTuple中第一个元素不是QueryType,
// 值为1+QueryType在RestTuple中的位置。
template <typename QueryType, typename FirstType, typename... RestTypes>
struct tuple_element_index_helper<QueryType, std::tuple<FirstType, RestTypes...>>
{
using RestTuple = std::tuple<RestTypes...>;
static constexpr size_t value = 1 + tuple_element_index_helper<QueryType, RestTuple>::value;
};
// ----------------------帮助函数,隐藏细节--------------------
template <typename QueryType, typename TypeTuple>
constexpr size_t get_type_index()
{
constexpr size_t index = tuple_element_index_helper<QueryType, TypeTuple>::value;
// 如果QueryType不在TypeTuple中,编译错误。
static_assert(index < std::tuple_size_v<TypeTuple>);
return index;
}
void type_index_in_typetuple_test()
{
static_assert(get_type_index<int, std::tuple<int>>() == 0);
// static_assert(get_type_index<int, std::tuple<double>>() == 0); // compile error
static_assert(get_type_index<double, std::tuple<double,int,double>>() == 0);
static_assert(get_type_index<double, std::tuple<int, double,double>>() == 1);
static_assert(get_type_index<double, std::tuple<double,int>>() == 0);
static_assert(get_type_index<int, std::tuple<double,int>>() == 1);
static_assert(get_type_index<bool, std::tuple<bool, double,int>>() == 0);
static_assert(get_type_index<double, std::tuple<bool, double,int>>() == 1);
static_assert(get_type_index<int, std::tuple<bool, double,int>>() == 2);
using StringTypeTestTuple = std::tuple<decltype("x"_t), decltype("xy"_t),decltype("y"_t),decltype("z"_t)>;
static_assert(get_type_index<decltype("x"_t), StringTypeTestTuple>() == 0);
static_assert(get_type_index<decltype("xy"_t), StringTypeTestTuple>() == 1);
static_assert(get_type_index<decltype("y"_t), StringTypeTestTuple>() == 2);
static_assert(get_type_index<decltype("z"_t), StringTypeTestTuple>() == 3);
// get_type_index<decltype("ttt"_t), StringTypeTestTuple>(); // compile error
}
//------------------------------------------------------------------------
template <typename ValueType, typename... Ts>
class NamedVectorWrapper
{
// 目前只支持算数类型或者枚举类型
static_assert(std::disjunction_v<std::is_arithmetic<ValueType>, std::is_enum<ValueType>>);
// Ts必须互补相同
static_assert(is_distinct_type_list<Ts...>::value);
public:
using NameTypeTuple = std::tuple<Ts...>;
explicit NamedVectorWrapper(ValueType* ptr): array_ptr_(ptr) {}
NamedVectorWrapper(const NamedVectorWrapper&) = default;
NamedVectorWrapper(NamedVectorWrapper&&) = default;
NamedVectorWrapper& operator=(const NamedVectorWrapper&) = default;
NamedVectorWrapper& operator=(NamedVectorWrapper&&) = default;
template <typename T>
ValueType operator[](const T&) const
{
const size_t index = get_type_index<T, NameTypeTuple>();
return array_ptr_[index];
}
template <typename T>
ValueType& operator[](const T&)
{
const size_t index = get_type_index<T, NameTypeTuple>();
return array_ptr_[index];
}
private:
ValueType* array_ptr_=nullptr;
};
int main(int argc, char** argv)
{
string_defined_type_test();
type_index_in_typetuple_test();
is_distinct_type_list_test();
using QuaternionWXYZWrapper = NamedVectorWrapper<double,
decltype("w"_t),
decltype("x"_t),
decltype("y"_t),
decltype("z"_t)>;
using ConstQuaternionWXYZWrapper = NamedVectorWrapper<const double,
decltype("w"_t),
decltype("x"_t),
decltype("y"_t),
decltype("z"_t)>;
double quaternion_array[4]={1.0,2.0,3.0,4.0};
QuaternionWXYZWrapper quat_wrapper(quaternion_array);
assert(quat_wrapper["w"_t] == 1.0);
assert(quat_wrapper["x"_t] == 2.0);
assert(quat_wrapper["y"_t] == 3.0);
assert(quat_wrapper["z"_t] == 4.0);
// assert(quat_wrapper["ww"_t]); // if given unknown param name, compile error.
quat_wrapper["w"_t] = 11.0;
quat_wrapper["x"_t] = 12.0;
quat_wrapper["y"_t] = 13.0;
quat_wrapper["z"_t] = 14.0;
assert(quaternion_array[0] == 11.0);
assert(quaternion_array[1] == 12.0);
assert(quaternion_array[2] == 13.0);
assert(quaternion_array[3] == 14.0);
const double const_quaternion_array[4]={1.0,2.0,3.0,4.0};
ConstQuaternionWXYZWrapper const_quat_wrapper(const_quaternion_array);
assert(const_quat_wrapper["x"_t] == 2.0);
// const_quat_wrapper["x"_t] = 2.0; compile error.
return EXIT_SUCCESS;
}