方案一:
通过std::is_base_of_v函数来判断
#include <iostream>
#include <type_traits>
// 定义类型特征模板 is_base_or_derived,用于判断 T 是否是 BaseType 或其派生类
template <typename T, typename BaseType>
struct is_base_or_derived {
static constexpr bool value = std::is_base_of_v<BaseType, T>;
//static constexpr bool value = std::is_base_of<BaseType, T>::value;
};
// 定义模板函数,限定输入参数类型为 BaseType 或其派生类,否则编译报错
// 定义一个基类和一个派生类
class Base {};
class Derived : public Base {};
template <typename T>
void foo(T arg) {
static_assert(is_base_or_derived<T, Base>::value, "T must be derived from Base");
// 函数实现...
}
class C {};
int main() {
Base b;
Derived d;
C c;
// 调用 foo 函数,传入一个派生类对象
foo(d); // 编译通过
foo(b); // 编译通过
//foo(c); // 编译报错
return 0;
}
方案二:
使用std::enable_if
#include <iostream>
#include <type_traits>
class Base{};
class A{};
class B:public Base
{};
template<typename T, typename = typename std::enable_if<std::is_base_of_v<Base, T>, void>::type>
void func(T& t)
{
std::cout<<"hello world"<<std::endl;
}
int main() {
Base base;
B b;
A a;
func(base);
func(b);
//func(a); // 编译报错
return 0;
}