在TVM Object类型系统中最重要的是三个类:Object、ObjectPtr、ObjectRef
为什么需要这三个类?
设计目的:为了能够在不更改python前端的情况下扩展c++中的语言对象,且能够对任何语言对象序列化。
- Object:编译器中所有的语言对象(命名一般以Node结尾)都是Object的子类,Object的子类保存了一般保存了数据成员变量
- ObjectPtr:Object的子定义智能指针,用于进行内存管理
- ObjectRef:Object的引用,ObjectRef的子类一般包含操作Object类的函数
代码实现
class TVM_DLL Object {
protected:
uint32_t type_index_{0};
RefCounterType ref_counter_{0};
FDeleter deleter_ = nullptr;
inline void IncRef(); // 增加引用计数
inline void DecRef(); // 减少引用计数
private:
inline int use_count() const; // 返回引用计数
};
template <typename T>
class ObjectPtr {
public:
// 默认构造函数
ObjectPtr() {}
ObjectPtr(std::nullptr_t) {}
// 拷贝构造,会调用构造函数从而增加引用数
ObjectPtr(const ObjectPtr<T>& other) : ObjectPtr(other.data_) {}
template <typename U>
ObjectPtr(const ObjectPtr<U>& other) : ObjectPtr(other.data_) {
static_assert(std::is_base_of<T, U>::value,"can only assign of child class ObjectPtr to parent");
}
// 移动构造函数,不调用构造函数,引用数不变
ObjectPtr(ObjectPtr<T>&& other) : data_(other.data_) { other.data_ = nullptr; }
template <typename Y>
ObjectPtr(ObjectPtr<Y>&& other) : data_(other.data_) {
static_assert(std::is_base_of<T, Y>::value, "can only assign of child class ObjectPtr to parent");
other.data_ = nullptr;
}
// 析构函数
~ObjectPtr() {
if (data_ != nullptr) {
data_->DecRef();
data_ = nullptr;
}
// 使用计数
int use_count() const { return data_ != nullptr ? data_->use_count() : 0; }
// 访问成员变量
T* get() const { return static_cast<T*>(data_); }
T* operator->() const { return get(); }
T& operator*() const { return *get(); }
private:
Object* data_{nullptr};
// 构造函数,explicit意味着参数不能进行隐式转换
explicit ObjectPtr(Object* data) : data_(data) {
if (data != nullptr) {
data_->IncRef();
}
}
};
class ObjectRef {
public:
const Object* get() const { return data_.get(); }
const Object* operator->() const { return get(); }
using ContainerType = Object;
protected:
ObjectPtr<Object> data_;
};
ObjectPtr
类中有Object*
类型的data_
成员变量,可以通过->
操作符和get()
函数返回Object*
指针。
ObjectRef
类中有ObjectPtr<Object>
类型的data_
成员变量,也可以通过->
操作符和get()
函数返回Object*
指针(调用了data_
成员变量的get()
函数)。
我们以StringObj、String类为例来展示功能:
class StringObj : public Object {
public:
const char* data;
uint64_t size;
static constexpr const uint32_t _type_index = TypeIndex::kRuntimeString;
static constexpr const char* _type_key = "runtime.String";
TVM_DECLARE_FINAL_OBJECT_INFO(StringObj, Object);
private:
class FromStd; // 内部类,用于从std::string初始化data和size
};
class StringObj::FromStd : public StringObj {
public:
explicit FromStd(std::string other) : data_container{other} {}
private:
std::string data_container;
};
class String : public ObjectRef {
public:
String() : String(std::string()) {}
String(std::string other);
String(const char* other)
: String(std::string(other)) {}
String(std::nullptr_t)
: ObjectRef(nullptr) {}
const char* c_str() const { return get()->data; }
const char* data() const { return get()->data; }
// 类型转换运算符重载
operator std::string() const { return std::string{get()->data, size()}; }
size_t size() const {
const auto* ptr = get();
return ptr->size;
}
size_t length() const { return size(); }
bool empty() const { return size() == 0; }
char at(size_t pos) const {...}
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);
};
StringObj
类中定义了字符串的首字符的指针和字符串的长度,定义了类型键和类型索引,并使用了TVM_DECLARE_FINAL_OBJECT_INFO
宏,定义如下:
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \
static const constexpr bool _type_final = true; \
static const constexpr int _type_child_slots = 0; \
TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
static_assert(!ParentType::_type_final, "ParentObj marked as final"); \
static uint32_t RuntimeTypeIndex() { \
static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \
TypeName::_type_child_slots < ParentType::_type_child_slots, \
"Need to set _type_child_slots when parent specifies it."); \
if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \
return TypeName::_type_index; \
} \
return _GetOrAllocRuntimeTypeIndex(); \
} \
static uint32_t _GetOrAllocRuntimeTypeIndex() { \
static uint32_t tindex = Object::GetOrAllocRuntimeTypeIndex( \
TypeName::_type_key, TypeName::_type_index, ParentType::_GetOrAllocRuntimeTypeIndex(), \
TypeName::_type_child_slots, TypeName::_type_child_slots_can_overflow); \
return tindex; \
}
TVM_DECLARE_BASE_OBJECT_INFO
宏定义了两个静态变量以及两个静态函数,RuntimeTypeIndex()
用于获取运行时类型索引,在StringObj
构造时type_index_
成员变量的值通过该函数得到。
tvm::runtime::String
类中定义了对StringObj
进行操作的函数,如size()
,并使用了TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS
宏,定义如下:
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \
const ObjectName* get() const { return operator->(); } \
static constexpr bool _type_is_nullable = false; \
using ContainerType = ObjectName;
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS
宏定义了一个构造函数,重载了->
运算符和get()
函数(返回值类型不再是Object*
而是StringObj*
),标记该类型是不可为空的,定义了容器类型。
构造过程
构造函数String(std::string other);
的实现如下:
inline String::String(std::string other) {
auto ptr = make_object<StringObj::FromStd>(std::move(other));
ptr->size = ptr->data_container.size();
ptr->data = ptr->data_container.data();
data_ = std::move(ptr);
}
可以看到在构造tvm::runtime::String
时,首先使用分配器由std::string
参数创建一个ObjectPtr<StringObj::FromStd>
类型的对象。分配器创建完成后将StringObj::FromStd
的data_container
的size和data直接赋值给StringObj
的size和data,然后将ptr移动到tvm::runtime::String
的data_
。
// include/tvm/runtime/memory.h
template <typename T, typename... Args>
inline ObjectPtr<T> make_object(Args&&... args) {
return SimpleObjAllocator().make_object<T>(std::forward<Args>(args)...);
}
template <typename Derived>
class ObjAllocatorBase {
public:
template <typename T, typename... Args>
inline ObjectPtr<T> make_object(Args&&... args) {
using Handler = typename Derived::template Handler<T>;
static_assert(std::is_base_of<Object, T>::value, "make can only be used to create Object");
T* ptr = Handler::New(static_cast<Derived*>(this), std::forward<Args>(args)...);
ptr->type_index_ = T::RuntimeTypeIndex();
ptr->deleter_ = Handler::Deleter();
return ObjectPtr<T>(ptr);
}
}
使用分配器创建时把std::string
移动到了data_container
中,这时调用了StringObj
的构造函数,然后设置了StringObj
的type_index_
和deleter_
,最后构造(通过ObjectPtr(Object* data)
)并返回了ObjectPtr
,在构造时增加了Object
的引用数。
总结
构造ObjectRef的过程中对Object进行了构造,主要使用Object保存数据,ObjectPtr用于内存管理,ObjectRef表示引用并对数据进行操作。