TVM Object类型系统

在TVM Object类型系统中最重要的是三个类:Object、ObjectPtr、ObjectRef

为什么需要这三个类?

设计目的:为了能够在不更改python前端的情况下扩展c++中的语言对象,且能够对任何语言对象序列化。

  • Object:编译器中所有的语言对象(命名一般以Node结尾)都是Object的子类,Object的子类保存了一般保存了数据成员变量
  • ObjectPtr:Object的子定义智能指针,用于进行内存管理
  • ObjectRef:Object的引用,ObjectRef的子类一般包含操作Object类的函数

代码实现

class TVM_DLL Object {
 protected:
  uint32_t type_index_{0};
  RefCounterType ref_counter_{0};
  FDeleter deleter_ = nullptr;

  inline void IncRef(); // 增加引用计数
  inline void DecRef(); // 减少引用计数

 private:
  inline int use_count() const; // 返回引用计数
};


template <typename T>
class ObjectPtr {
 public:
  // 默认构造函数
  ObjectPtr() {}
  ObjectPtr(std::nullptr_t) {}
  // 拷贝构造,会调用构造函数从而增加引用数
  ObjectPtr(const ObjectPtr<T>& other) : ObjectPtr(other.data_) {}
  template <typename U>
  ObjectPtr(const ObjectPtr<U>& other) : ObjectPtr(other.data_) {
    static_assert(std::is_base_of<T, U>::value,"can only assign of child class ObjectPtr to parent");
  }
  // 移动构造函数,不调用构造函数,引用数不变
  ObjectPtr(ObjectPtr<T>&& other) : data_(other.data_) { other.data_ = nullptr; }
  template <typename Y>
  ObjectPtr(ObjectPtr<Y>&& other) : data_(other.data_) {
    static_assert(std::is_base_of<T, Y>::value, "can only assign of child class ObjectPtr to parent");
    other.data_ = nullptr;
  }
  // 析构函数
  ~ObjectPtr() {
    if (data_ != nullptr) {
      data_->DecRef();
      data_ = nullptr;
  }
  // 使用计数
  int use_count() const { return data_ != nullptr ? data_->use_count() : 0; }
  // 访问成员变量
  T* get() const { return static_cast<T*>(data_); }
  T* operator->() const { return get(); }
  T& operator*() const { return *get(); }

 private:
  Object* data_{nullptr};
  // 构造函数,explicit意味着参数不能进行隐式转换
  explicit ObjectPtr(Object* data) : data_(data) {
    if (data != nullptr) {
      data_->IncRef();
    }
  }
};


class ObjectRef {
 public:
  const Object* get() const { return data_.get(); }
  const Object* operator->() const { return get(); }

  using ContainerType = Object;

 protected:
  ObjectPtr<Object> data_;
};

ObjectPtr类中有Object*类型的data_成员变量,可以通过->操作符和get()函数返回Object*指针。
ObjectRef类中有ObjectPtr<Object>类型的data_成员变量,也可以通过->操作符和get()函数返回Object*指针(调用了data_成员变量的get()函数)。

我们以StringObj、String类为例来展示功能:

class StringObj : public Object {
 public:
  const char* data;
  uint64_t size;

  static constexpr const uint32_t _type_index = TypeIndex::kRuntimeString;
  static constexpr const char* _type_key = "runtime.String";
  TVM_DECLARE_FINAL_OBJECT_INFO(StringObj, Object);

 private:
  class FromStd; // 内部类,用于从std::string初始化data和size
};

class StringObj::FromStd : public StringObj {
 public:
  explicit FromStd(std::string other) : data_container{other} {}

 private:
  std::string data_container;
};


class String : public ObjectRef {
 public:
  String() : String(std::string()) {}
  String(std::string other);
  String(const char* other)
      : String(std::string(other)) {}
  String(std::nullptr_t)
      : ObjectRef(nullptr) {}

  const char* c_str() const { return get()->data; }
  const char* data() const { return get()->data; }
  // 类型转换运算符重载
  operator std::string() const { return std::string{get()->data, size()}; }

  size_t size() const {
    const auto* ptr = get();
    return ptr->size;
  }
  size_t length() const { return size(); }
  bool empty() const { return size() == 0; }
  char at(size_t pos) const {...}

  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);
};

StringObj类中定义了字符串的首字符的指针和字符串的长度,定义了类型键和类型索引,并使用了TVM_DECLARE_FINAL_OBJECT_INFO宏,定义如下:

#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \
  static const constexpr bool _type_final = true;           \
  static const constexpr int _type_child_slots = 0;         \
  TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
  
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)                                     \
  static_assert(!ParentType::_type_final, "ParentObj marked as final");                        \
  static uint32_t RuntimeTypeIndex() {                                                         \
    static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 ||    \
                      TypeName::_type_child_slots < ParentType::_type_child_slots,             \
                  "Need to set _type_child_slots when parent specifies it.");                  \
    if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) {                        \
      return TypeName::_type_index;                                                            \
    }                                                                                          \
    return _GetOrAllocRuntimeTypeIndex();                                                      \
  }                                                                                            \
  static uint32_t _GetOrAllocRuntimeTypeIndex() {                                              \
    static uint32_t tindex = Object::GetOrAllocRuntimeTypeIndex(                               \
        TypeName::_type_key, TypeName::_type_index, ParentType::_GetOrAllocRuntimeTypeIndex(), \
        TypeName::_type_child_slots, TypeName::_type_child_slots_can_overflow);                \
    return tindex;                                                                             \
  }

TVM_DECLARE_BASE_OBJECT_INFO宏定义了两个静态变量以及两个静态函数,RuntimeTypeIndex()用于获取运行时类型索引,在StringObj构造时type_index_成员变量的值通过该函数得到。

tvm::runtime::String类中定义了对StringObj进行操作的函数,如size(),并使用了TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS宏,定义如下:

#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)            \
  explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {}    \
  TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName);                                           \
  const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \
  const ObjectName* get() const { return operator->(); }                                       \
  static constexpr bool _type_is_nullable = false;                                             \
  using ContainerType = ObjectName;

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS宏定义了一个构造函数,重载了->运算符和get()函数(返回值类型不再是Object*而是StringObj*),标记该类型是不可为空的,定义了容器类型。

构造过程

构造函数String(std::string other);的实现如下:

inline String::String(std::string other) {
  auto ptr = make_object<StringObj::FromStd>(std::move(other));
  ptr->size = ptr->data_container.size();
  ptr->data = ptr->data_container.data();
  data_ = std::move(ptr);
}

可以看到在构造tvm::runtime::String时,首先使用分配器由std::string参数创建一个ObjectPtr<StringObj::FromStd>类型的对象。分配器创建完成后将StringObj::FromStddata_container的size和data直接赋值给StringObj的size和data,然后将ptr移动到tvm::runtime::Stringdata_


// include/tvm/runtime/memory.h
template <typename T, typename... Args>
inline ObjectPtr<T> make_object(Args&&... args) {
  return SimpleObjAllocator().make_object<T>(std::forward<Args>(args)...);
}

template <typename Derived>
class ObjAllocatorBase {
 public:
  template <typename T, typename... Args>
  inline ObjectPtr<T> make_object(Args&&... args) {
    using Handler = typename Derived::template Handler<T>;
    static_assert(std::is_base_of<Object, T>::value, "make can only be used to create Object");
    T* ptr = Handler::New(static_cast<Derived*>(this), std::forward<Args>(args)...);
    ptr->type_index_ = T::RuntimeTypeIndex();
    ptr->deleter_ = Handler::Deleter();
    return ObjectPtr<T>(ptr);
  }
}

使用分配器创建时把std::string移动到了data_container中,这时调用了StringObj的构造函数,然后设置了StringObjtype_index_deleter_,最后构造(通过ObjectPtr(Object* data))并返回了ObjectPtr,在构造时增加了Object的引用数。

总结

构造ObjectRef的过程中对Object进行了构造,主要使用Object保存数据,ObjectPtr用于内存管理,ObjectRef表示引用并对数据进行操作。

  • 23
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值