libtorch c10::IValue类解析,从IValue获取值

c10::IValue是libtorch中的基础数据类型,很常见,之前一直用的云里雾里,在此整理一下

0. 简介

c10::IValue像一个数据容器,但是它又不用来直接存储数据,只是一层数据的封装。

怎么理解呢?c10::IValue可以存储libtorch里很多类型的数据,比如c10::IValue存储可能是一个Tensor,一组Tensor,或者是一个Moudle,甚至是一个int,所以c10::IValue更像是一种封装,对不同的数据类型进行了一次统一的封装,然后很多很多函数的接口都可以使用这种统一的数据类型了。

如果你用过opencv,那么你可以觉得眼熟,cv::InputArraycv::OutputArray不就是这么干的么,比如常用的cv::resize函数,它的输入、输出数据就是cv::InputArraycv::OutputArray,而不是直接使用cv::Mat,这其实就是一种封装的思想。

void resize( InputArray src, OutputArray dst,
             Size dsize, double fx = 0, double fy = 0,
             int interpolation = INTER_LINEAR );

1. 类的构造

先看一下c10::IValue的定义:

class c10::IValue {
  Payload payload;
  Tag tag;
  bool is_intrusive_ptr;
  }

union Payload {
    // We use a nested union here so that we can make the copy easy
    // and efficient in the non-tensor (i.e., trivially copyable)
    // case. Specifically, we do not have to do a switch-on-tag to
    // figure out which union member to assign; we can just use
    // TriviallyCopyablePayload::operator=.
    union TriviallyCopyablePayload {
      TriviallyCopyablePayload() : as_int(0) {}
      int64_t as_int;
      double as_double;
      bool as_bool;
      // Invariant: never nullptr; null state is represented as
      // c10::UndefinedTensorImpl::singleton() for consistency of
      // representation with Tensor.
      c10::intrusive_ptr_target* as_intrusive_ptr;
      struct {
        DeviceType type;
        DeviceIndex index;
      } as_device;
    } u;
    at::Tensor as_tensor;
    Payload() : u() {}
    ~Payload() {}
  };

c10::IValue只有3个成员变量,一个用于存储数据的payload,一个表示数据类型的tag,还有一个指示是不是others类型的is_intrusive_ptr,当然,还有很多很多成员函数,详情看这里,或者..../libtorch/include/ATen/core/ivalue.h文件

  1. Payload payloadc10::Payload是一个union类型,c10::IValuesIValue::Payload中包含这些数据的值,它将基本类型(int64_t, bool, double, Device)Tensor作为值,并将所有其他类型保存在c10::intrusive_ptr_target指针里边
  2. Tag tagc10::Tag是一个enum类型,表示c10::IValue里保存的是什么类型数据,可以支持下面这些类型
#define TORCH_FORALL_TAGS(_) \
  _(None)                    \
  _(Tensor)                  \
  _(Storage)                 \
  _(Double)                  \
  _(ComplexDouble)           \
  _(Int)                     \
  _(Bool)                    \
  _(Tuple)                   \
  _(String)                  \
  _(Blob)                    \
  _(GenericList)             \
  _(GenericDict)             \
  _(Future)                  \
  _(Device)                  \
  _(Stream)                  \
  _(Object)                  \
  _(PyObject)                \
  _(Uninitialized)           \
  _(Capsule)                 \
  _(RRef)                    \
  _(Quantizer)               \
  _(Generator)               \
  _(Enum)
  1. bool is_intrusive_ptr:一个bool值,是否为intrusive class,这个intrusive class是啥意思我也没太理解,大概可能就是非 [基本类型(int64_t, bool, double, Device)Tensor],其他都是intrusive class,比如TupleString之类的。如果为True的话,就得去c10::intrusive_ptr_target指针里读取数据了。

2. 用法

c10::IValue最主要的用法应该就是把数据取出来了,这一点从c10::IValue的成员函数也能看出来,一大半函数都是isXXX,toXXX之类的,转化为其他类型

简单说几个用法:

  1. 判断c10::IValue里边存储的什么类型
c10::IValue a = torch::ones({1, 3, 640, 640});
auto b = a.type().get()->kind();
auto c = c10::typeKindToString(b);
std::cout << c << std::endl;
  1. 获取数据
    使用c10::IValue::toXXX()函数
torch::Tensor t = ivalue.toTensor();  //TensorType
bool t = ivalue.toBool();  //BoolType
auto t = ivalue.toList();  //ListType
  • 4
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值