mxnet代码解析之mshadow

mshadow采用了表达式模板的技巧增强了c++矩阵库的性能。
mshadow用于数据存储结构的主要继承脉络如下:
Tensor->TRValue->RValueExp->Exp
继承链的顶端是所有表达式的基类Exp:

/*!
 * \brief base class of every expression (CRTP): SubType is the concrete
 *  expression type, DType the element type, exp_type the expression-category
 *  flag (e.g. kMapper / kChainer / kRValue / kComplex).
 */
template<typename SubType, typename DType, int exp_type>
struct Exp {
 public:
  /*! \return const reference to this object downcast to the concrete SubType */
  inline const SubType& self() const {
    return *static_cast<const SubType*>(this);
  }
  /*! \return mutable pointer to this object downcast to the concrete SubType */
  inline SubType* ptrself() {
    return static_cast<SubType*>(this);
  }
};

这里Exp定义的精髓就是通过self或ptrself可以获得SubType的引用或指针,这为后面将SubType表达式作为模板参数传递后再获得SubType提供了途径。
RValueExp继承于Exp,是所有右值的基类:

/*!
 * \brief base class of all right values (containers that own writable data);
 *  Container is the concrete container type, DType its element type.
 *  NOTE: several bodies below are elided by the article as "......".
 */
template<typename Container, typename DType>
class RValueExp: public Exp<Container, DType, type::kRValue> {
 public:
  /*! \return a lazy transpose expression wrapping this container */
  inline const TransposeExp<Container, DType> T(void) const {
    return TransposeExp<Container, DType>(this->self());
  }
  /*! \brief operator overload: update every element with scalar s via sv::plusto */
  inline Container &operator+=(DType s) {
    ExpEngine<sv::plusto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
    return *(this->ptrself());
  }
  /*! \brief operator overload (body elided in the article) */
  inline Container &operator-=(DType s) {......}
  /*! \brief operator overload (body elided in the article) */
  inline Container &operator*=(DType s) {......}
  /*! \brief operator overload (body elided in the article) */
  inline Container &operator/=(DType s) {......}

  /*! \brief assign scalar s to every element via sv::saveto */
  inline Container &__assign(DType s) {
    ExpEngine<sv::saveto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
    return *(this->ptrself());
  }
  /*! \brief assign an arbitrary expression; we can not define container = container */
  template<typename E, int etype>
  inline Container &__assign(const Exp<E, DType, etype> &exp) {
    ExpEngine<sv::saveto, Container, DType>::Eval(this->ptrself(), exp.self());
    return *(this->ptrself());
  }
  /*! \brief operator overload, assign (declared here, defined elsewhere) */
  inline Container &__assign(const Exp<Container, DType, type::kRValue> &exp);

  /*! \brief implementation of operator+= for expression right-hand sides */
  template<typename E, int etype>
  inline Container &operator+=(const Exp<E, DType, etype> &exp) {
    ExpEngine<sv::plusto, Container, DType>::Eval(this->ptrself(), exp.self());
    return *(this->ptrself());
  }
  /*! \brief implementation of operator-= (body elided in the article) */
  template<typename E, int etype>
  inline Container &operator-=(const Exp<E, DType, etype> &exp) {......}
  /*! \brief implementation of operator*= (body elided in the article) */
  template<typename E, int etype>
  inline Container &operator*=(const Exp<E, DType, etype> &exp) {......}
  /*! \brief implementation of operator/= (body elided in the article) */
  template<typename E, int etype>
  inline Container &operator/=(const Exp<E, DType, etype> &exp) {......}
};

RValueExp的各运算符重载中所调用的ExpEngine类,针对四种表达式类型(kMapper、kChainer、kRValue、kComplex)重载了Eval函数:

/*!
 * \brief dispatcher that evaluates an expression into a destination right
 *  value; SV is the save/update operator (e.g. sv::saveto, sv::plusto),
 *  RV the destination container type, DType the element type.
 */
template<typename SV, typename RV, typename DType>
struct ExpEngine {
  /*! \brief element-wise mapper expressions go through MapExp */
  template<typename E>
  inline static void Eval(RV *dst,const Exp<E, DType, type::kMapper> &exp) {
    MapExp<SV>(dst, exp);
  }
  /*! \brief chainer expressions are also handled by MapExp */
  template<typename E>
  inline static void Eval(RV *dst,const Exp<E, DType, type::kChainer> &exp) {
    MapExp<SV>(dst, exp);
  }
  /*! \brief right-value expressions are also handled by MapExp */
  template<typename E>
  inline static void Eval(RV *dst,const Exp<E, DType, type::kRValue> &exp) {
    MapExp<SV>(dst, exp);
  }
  // used for dot: complex expressions are redirected to ExpComplexEngine
  template<typename E>
  inline static void Eval(RV *dst,const Exp<E, DType, type::kComplex> &exp) {
    ExpComplexEngine<SV, RV, E, DType>::Eval(dst->ptrself(), exp.self());
  }
};

TRValue是所有可能的tensor的超类:

/*!
 * \brief super class of all possible tensors; carries the static Device and
 *  dimension information purely in its template parameters (no members).
 */
template<typename Container, typename Device, int dimension, typename DType>
struct TRValue: public expr::RValueExp<Container, DType> {
};

终于到了Tensor类的定义:

/*!
 * \brief general tensor of fixed rank on a given device; Device selects
 *  cpu/gpu, dimension is the static rank, DType the element type
 *  (MSHADOW_DEFAULT_DTYPE supplies the default).
 *  NOTE: constructors and some members are elided by the article as "......".
 */
template<typename Device, int dimension,
         typename DType MSHADOW_DEFAULT_DTYPE>
struct Tensor: public TRValue<Tensor<Device, dimension, DType>,
                              Device, dimension, DType> {
 public:
  /*! \brief whether the device of this tensor is CPU */
  static const bool kDevCPU = Device::kDevCPU;
  /*! \brief rank of a sub-tensor obtained by fixing the first index */
  static const int  kSubdim = dimension - 1;

  /*! \brief pointer to the data */
  DType *dptr_;
  /*! \brief shape of the tensor */
  Shape<dimension> shape_;
  /*!
   * \brief storing the stride information in x dimension
   *    this is used to deal with pitch allocation in gpu or sse(align x dimension to 64bit) for efficiency
   */
  index_t stride_;
  /*! \brief stream on which computation with this tensor is scheduled */
  Stream<Device> *stream_;

  // various constructors (elided in the article)
  ......
  /*! \brief construct a Tensor from a data pointer, shape, stride and stream */
  MSHADOW_XINLINE Tensor(DType *dptr,
                         const Shape<dimension> &shape,
                         index_t stride, Stream<Device> *stream)
      : dptr_(dptr), shape_(shape), stride_(stride), stream_(stream) {}
  ......
  ......

  /*!
   * \brief slice [begin, end) along dimension 0; returns a view that shares
   *  the underlying memory (pointer offset only, no copy)
   */
  MSHADOW_XINLINE Tensor<Device, dimension, DType>
  Slice(index_t begin, index_t end) const {
    Shape<dimension> s = this->shape_;
    s[0] = end - begin;
    return Tensor<Device, dimension, DType>(dptr_ + this->MemSize<1>() * begin,
                                            s, stride_, stream_);
  }
  /*!\brief implement the assignment of same type: shallow copy of the view fields only */
  inline Tensor<Device, dimension, DType> &
  operator=(const Tensor<Device, dimension, DType> &exp) {
    dptr_ = exp.dptr_;
    shape_ = exp.shape_;
    stride_ = exp.stride_;
    stream_ = exp.stream_;
    return *this;
  }
  /*!\brief assignment from an expression template: the expression is evaluated here (lazy evaluation) */
  template<typename E, int etype>
  inline Tensor<Device, dimension, DType> &
  operator=(const expr::Exp<E, DType, etype> &exp) {
    return this->__assign(exp);
  }
  /*!\brief assignment from a scalar: broadcast the value to every element */
  inline Tensor<Device, dimension, DType> &operator=(const DType &exp) {
    return this->__assign(exp);
  }
};

Tensor的shape与numpy.shape不一样,最低维度从shape_[0]开始,重载操作符“=”除了拷贝已有Tensor,还可赋值中间运算结果表达式Exp,以及赋值标量。这里对operator =的重载将运算操作延迟到了赋值阶段,实现了Lazy Evaluation,避免了临时内存分配。特别地,DotExp在operator =中实行lazily evaluate, 将矩阵的乘法重定向到了blas库。

mshadow用于表达式操作的类(DotExp、BinaryMapExp、UnaryMapExp)同样继承于Exp基类,其特点是该表达式操作类自身也作为模板参数传递给Exp,以BinaryMapExp为例:

/*!
 * \brief expression node for an element-wise binary operation: OP is the
 *  operator, TA / TB the left and right operand expression types.
 */
template<typename OP, typename TA, typename TB, typename DType, int etype>
struct BinaryMapExp: public Exp<BinaryMapExp<OP, TA, TB, DType, etype>,
                                DType, etype> {
  /*! \brief left operand */
  const TA &lhs_;
  /*! \brief right operand */
  const TB &rhs_;
  /*! \brief build the node by storing references to both operands */
  explicit BinaryMapExp(const TA &left, const TB &right)
      : lhs_{left}, rhs_{right} {}
};

/*!
 * \brief construct a BinaryMapExp applying OP to two operand expressions;
 *  the category flag of the result merges both operands' flags with kMapper.
 */
template<typename OP, typename TA, typename TB, typename DType, int ta, int tb>
inline BinaryMapExp<OP, TA, TB, DType, (ta|tb|type::kMapper)>
MakeExp(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
  typedef BinaryMapExp<OP, TA, TB, DType, (ta|tb|type::kMapper)> NodeType;
  return NodeType(lhs.self(), rhs.self());
}

/*!
 * \brief user-facing helper for a custom binary operation:
 *  F<OP>(lhs, rhs) describes a new element-wise binary computation.
 */
template<typename OP, typename TA, typename TB, typename DType, int ta, int tb>
inline BinaryMapExp<OP, TA, TB, DType, (ta|tb|type::kMapper)>
F(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
  // forward every deduced argument explicitly to the node factory
  return MakeExp<OP, TA, TB, DType, ta, tb>(lhs, rhs);
}

/*! \brief operator overload: lhs + rhs builds an op::plus expression node */
template<typename TA, typename TB, typename DType, int ta, int tb>
inline BinaryMapExp<op::plus, TA, TB, DType, (ta|tb|type::kMapper)>
operator+(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
  // equivalent to F<op::plus>(lhs, rhs)
  return MakeExp<op::plus, TA, TB, DType, ta, tb>(lhs, rhs);
}
......
......

BinaryMapExp是双目运算的表达式类,MakeExp是用来生成BinaryMapExp类对象的函数,F是自定义操作的函数,F< OP >(lhs, rhs)描述了一个新的双目运算,除此以外,+-*/等操作符重载函数也调用MakeExp创建BinaryMapExp。
这些用于表达式操作的类(DotExp、BinaryMapExp、UnaryMapExp)表示一个运算操作的中间结果,且可以递归表示(由Plan的Eval函数完成),实现了lengthy equations的解析。

真正用于递归调用eval的是Plan类:

/*!
 * \brief evaluation plan: the per-element kernel derived from an expression.
 *  Each expression type specializes Plan; Eval(y, x) yields the value of the
 *  expression at position [y][x].
 */
template<typename ExpType, typename DType>
class Plan {
 public:
  /*!
   * \brief evaluate the expression at index [y][x]
   *  to be implemented by SubType, for RValue, the return type will be DType &
   */
  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const;
};

// Plan specialization for a concrete Tensor: reads elements directly from
// the tensor's memory using its row stride.
template <typename Device, int dim, typename DType>
class Plan<Tensor<Device, dim, DType>, DType> {
 public:
  /*! \brief capture the tensor's data pointer and stride */
  explicit Plan(const Tensor<Device, dim, DType> &t)
      : dptr_{t.dptr_}, stride_{t.stride_} {}
  /*! \brief const evaluation at [y][x] */
  MSHADOW_XINLINE const DType &Eval(index_t y, index_t x) const {
    return dptr_[y * stride_ + x];
  }
  /*! \brief mutable access; for RValue the return type is a reference */
  MSHADOW_XINLINE DType &REval(index_t y, index_t x) {
    return dptr_[y * stride_ + x];
  }

 private:
  DType  *dptr_;
  index_t stride_;
};
......
......
// Plan specialization for a binary expression node: evaluates both child
// plans at [y][x] and combines the two results with OP::Map (this is where
// the recursive expression evaluation happens).
template<typename OP, typename TA, typename TB, int etype, typename DType>
class Plan<BinaryMapExp<OP, TA, TB, DType, etype>, DType> {
 public:
  /*! \brief store copies of the two child plans */
  explicit Plan(const Plan<TA, DType> &lhs, const Plan<TB, DType> &rhs)
      : lhs_{lhs}, rhs_{rhs} {}
  /*! \brief recursively evaluate children, then apply OP */
  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
    const DType left = lhs_.Eval(y, x);
    const DType right = rhs_.Eval(y, x);
    return OP::Map(left, right);
  }

 private:
  Plan<TA, DType> lhs_;
  Plan<TB, DType> rhs_;
};

调用模板函数MakePlan从表达式生成Plan,再以BinaryMapExp表达式为例:

/*! \brief build the Plan of a container (right value) expression */
template<typename T, typename DType>
inline Plan<T, DType> MakePlan(const RValueExp<T, DType> &e) {
  const T &container = e.self();
  return Plan<T, DType>(container);
}

/*! \brief build the Plan of a binary node by planning both children recursively */
template<typename OP, typename TA, typename TB, typename DType, int etype>
inline Plan<BinaryMapExp<OP, TA, TB, DType, etype>, DType>
MakePlan(const BinaryMapExp<OP, TA, TB, DType, etype> &e) {
  typedef Plan<BinaryMapExp<OP, TA, TB, DType, etype>, DType> PlanType;
  return PlanType(MakePlan(e.lhs_), MakePlan(e.rhs_));
}

已知ExpEngine中重载的eval中调用了MapExp来实现将Plan的结果传递给目标Exp的作用,MapExp又通过直接或间接(为处理sse优化而加入一个中间类MapExpCPUEngine)调用MapPlan实现上述作用。

Tensor的维数在定义后就固定了,因此在图模型中需要一个更为抽象灵活的数据结构,这就是TBlob:

/*!
 * \brief dynamically-typed, dynamically-ranked tensor blob: holds a raw data
 *  pointer plus runtime device/type tags, and converts to a concrete Tensor
 *  on demand. NOTE: several members/bodies are elided by the article.
 */
class TBlob {
 public:
  /*! \brief pointer to the data */
  void *dptr_;
  /*! \brief shape of the tensor */
  TShape shape_;
  /*!
   * \brief storing the stride information in x dimension
   */
  index_t stride_;
  /*! \brief device mask of the corresponding device */
  int dev_mask_;
  /*! \brief type flag of the tensor blob */
  int type_flag_;
  ......
  /*! \brief fetch a fixed-rank Tensor view of the data (body elided) */
  template<typename Device, int dim, typename DType>
  inline Tensor<Device, dim, DType> get(Stream<Device> *stream = NULL) const {...}

  /*!
   * \brief fetch a Tensor with an explicitly given shape; checks device,
   *  dtype, contiguity and total element count before building the view
   */
  template<typename Device, int dim, typename DType>
  inline Tensor<Device, dim, DType> get_with_shape(const Shape<dim> &shape,
                                                   Stream<Device> *stream = NULL) const
  {
    CHECK(Device::kDevMask == dev_mask_)
      << "TBlob.get: device type do not match specified type";
    CHECK(DataType<DType>::kFlag == type_flag_)
      << "TBlob.get_with_shape: data type do not match specified type."
      << "Expected: " << type_flag_ << " v.s. given " << DataType<DType>::kFlag;
    CHECK_EQ(this->CheckContiguous(), true) << "TBlob.get_reshape: must be contiguous";
    CHECK_EQ(this->shape_.Size(), shape.Size())
      << "TBlob.get_with_shape: new and old shape do not match total elements";
    return Tensor<Device, dim, DType>(static_cast<DType*>(dptr_),
                                      shape,
                                      shape[dim - 1],
                                      stream);
  }
  ......
  /*! \brief flatten to a 2-D tensor (body elided in the article) */
  template<typename Device, typename DType>
  inline Tensor<Device, 2, DType> FlatTo2D(Stream<Device> *stream = NULL) const {}
  ......
  /*! \brief flatten to a 3-D tensor around the given axis (body elided) */
  template<typename Device, typename DType>
  inline Tensor<Device, 3, DType> FlatTo3D(int axis, Stream<Device> *stream = NULL)  
  const {}
  ......                           
}

TBlob不涉及任何算术运算,也没有隐式的内存分配与释放,它就像一个指针类,在需要的时候调用get、get_with_shape、FlatTo2D、FlatTo3D等获得固定维数的Tensor来做更多的操作。TShape与TBlob类似,在需要的时候调用get、FlatTo2D等获得Tensor对应的Shape。

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值