c10::IValue
是libtorch中的基础数据类型,很常见,之前一直用的云里雾里,在此整理一下
0. 简介
c10::IValue
像一个数据容器,但是它又不用来直接存储数据,只是一层数据的封装。
怎么理解呢?c10::IValue
可以存储libtorch里很多类型的数据,比如c10::IValue
存储可能是一个Tensor
,一组Tensor
,或者是一个Moudle
,甚至是一个int
,所以c10::IValue
更像是一种封装,对不同的数据类型进行了一次统一的封装,然后很多很多函数的接口都可以使用这种统一的数据类型了。
如果你用过opencv
,那么你可以觉得眼熟,cv::InputArray
,cv::OutputArray
不就是这么干的么,比如常用的cv::resize
函数,它的输入、输出数据就是cv::InputArray
,cv::OutputArray
,而不是直接使用cv::Mat
,这其实就是一种封装的思想。
void resize( InputArray src, OutputArray dst,
Size dsize, double fx = 0, double fy = 0,
int interpolation = INTER_LINEAR );
1. 类的构造
先看一下c10::IValue
的定义:
class c10::IValue {
Payload payload;
Tag tag;
bool is_intrusive_ptr;
}
union Payload {
// We use a nested union here so that we can make the copy easy
// and efficient in the non-tensor (i.e., trivially copyable)
// case. Specifically, we do not have to do a switch-on-tag to
// figure out which union member to assign; we can just use
// TriviallyCopyablePayload::operator=.
union TriviallyCopyablePayload {
TriviallyCopyablePayload() : as_int(0) {}
int64_t as_int;
double as_double;
bool as_bool;
// Invariant: never nullptr; null state is represented as
// c10::UndefinedTensorImpl::singleton() for consistency of
// representation with Tensor.
c10::intrusive_ptr_target* as_intrusive_ptr;
struct {
DeviceType type;
DeviceIndex index;
} as_device;
} u;
at::Tensor as_tensor;
Payload() : u() {}
~Payload() {}
};
c10::IValue
只有3个成员变量,一个用于存储数据的payload
,一个表示数据类型的tag
,还有一个指示是不是others类型的is_intrusive_ptr
,当然,还有很多很多成员函数,详情看这里,或者..../libtorch/include/ATen/core/ivalue.h
文件
Payload payload
:c10::Payload
是一个union
类型,c10::IValues
在IValue::Payload
中包含这些数据的值,它将基本类型(int64_t, bool, double, Device)
和Tensor
作为值,并将所有其他类型保存在c10::intrusive_ptr_target指针里边
。Tag tag
:c10::Tag
是一个enum
类型,表示c10::IValue
里保存的是什么类型数据,可以支持下面这些类型
#define TORCH_FORALL_TAGS(_) \
_(None) \
_(Tensor) \
_(Storage) \
_(Double) \
_(ComplexDouble) \
_(Int) \
_(Bool) \
_(Tuple) \
_(String) \
_(Blob) \
_(GenericList) \
_(GenericDict) \
_(Future) \
_(Device) \
_(Stream) \
_(Object) \
_(PyObject) \
_(Uninitialized) \
_(Capsule) \
_(RRef) \
_(Quantizer) \
_(Generator) \
_(Enum)
bool is_intrusive_ptr
:一个bool值,是否为intrusive class
,这个intrusive class
是啥意思我也没太理解,大概可能就是非 [基本类型(int64_t, bool, double, Device)
和Tensor
],其他都是intrusive class
,比如Tuple
,String
之类的。如果为True的话,就得去c10::intrusive_ptr_target
指针里读取数据了。
2. 用法
c10::IValue
最主要的用法应该就是把数据取出来了,这一点从c10::IValue
的成员函数也能看出来,一大半函数都是isXXX,toXXX之类的,转化为其他类型
简单说几个用法:
- 判断
c10::IValue
里边存储的什么类型
c10::IValue a = torch::ones({1, 3, 640, 640});
auto b = a.type().get()->kind();
auto c = c10::typeKindToString(b);
std::cout << c << std::endl;
- 获取数据
使用c10::IValue::toXXX()
函数
torch::Tensor t = ivalue.toTensor(); //TensorType
bool t = ivalue.toBool(); //BoolType
auto t = ivalue.toList(); //ListType