根据源码中的注释,Variant 是一个类似 C++17 std::any、可以保存任意类型对象的容器。
源码在:tensorflow/core/framework/variant.h
TensorFlow 中的 Variant 是为存储 dtype=DT_VARIANT 的 Tensor 而设计的。
使用限制
放入 Variant 中的对象需满足如下条件:
- 有拷贝构造函数
- 有默认构造函数
- 是 protocol buffer 或 Tensor,或者定义了如下三个成员函数:
string TypeName() const;
void Encode(VariantTensorData* data) const;
bool Decode(VariantTensorData data);
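下面是一个示例(tensorflow/learn/test/variant_test.cpp)。它没有定义上述成员函数,而是为 MyVariantValue 提供了 EncodeVariant、DecodeVariant、TypeNameVariant 和 DebugStringVariant 的重载: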
//#include <any> //https://en.cppreference.com/w/cpp/utility/any
//#include <variant> //https://en.cppreference.com/w/cpp/utility/variant
//c++17才有 any variant
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/framework/variant.h"
using namespace tensorflow;
// Minimal user type stored in a tensorflow::Variant by the test below.
// `x` is value-initialized so that a default-constructed instance (for example
// one created before decoding) never holds an indeterminate value.
struct MyVariantValue {
  int x = 0;
};
// Restores a MyVariantValue from the VariantTensorData written by the matching
// EncodeVariant(const MyVariantValue&, VariantTensorData*) overload, which
// stores `x` as the metadata payload.
// Returns false for null arguments, or whatever
// VariantTensorData::get_metadata reports when the payload cannot be read
// back into an int.
bool DecodeVariant(VariantTensorData* data, MyVariantValue* value)
{
  if (data == nullptr || value == nullptr) {
    return false;
  }
  return data->get_metadata(&value->x);
}
bool DecodeVariant(std::string* buf, MyVariantValue* value)
{
return true;
}
void EncodeVariant(const MyVariantValue& value, VariantTensorData* data)
{
}
void EncodeVariant(const MyVariantValue& value, std::string* buf)
{
}
// Human-readable description of a MyVariantValue for debug output. The stored
// field is included so that different values can be told apart in logs.
string DebugStringVariant(const MyVariantValue &value)
{
  return "debug my variant value: x=" + std::to_string(value.x);
}
// Type-name hook for MyVariantValue. It always yields "MyVariant",
// independent of the value passed in.
string TypeNameVariant(const MyVariantValue &value) {
  static_cast<void>(value);  // the name does not depend on the stored value
  return string("MyVariant");
}
// Stores a MyVariantValue in a Variant and checks that the typed accessor
// returns the stored object with its field intact.
TEST(VariantTest, Variant)
{
  LOG(INFO) << "test tensorflow core framework variant";
  MyVariantValue t;
  t.x = 3;
  Variant x = t;
  // get<T>() hands back a pointer. Stop the test with a failure, rather than
  // dereferencing null, if the Variant does not hold a MyVariantValue.
  const MyVariantValue* stored = x.get<MyVariantValue>();
  ASSERT_NE(stored, nullptr);
  LOG(INFO) << stored->x;
  EXPECT_EQ(stored->x, 3);
}
//bazel test --test_output all -c opt //tensorflow/learn:test_variant_test
BUILD 文件位于 tensorflow/learn 目录下(tensorflow/learn/BUILD):
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
"tf_cc_tests",
"tf_copts",
"tf_cuda_library",
)
load(
"//tensorflow/core/platform:build_config.bzl",
"tf_kernel_tests_linkstatic",
"tf_proto_library",
"tf_pyclif_proto_library",
)
# One cc_test per entry in `srcs`. tf_cc_tests names each generated test after
# its source path (here: test_variant_test), so `name` below is not the target
# you pass to `bazel test`.
tf_cc_tests(
    name = "learn_test",
    size = "small",
    srcs = [
        "test/variant_test.cpp",
    ],
    linkopts = select({
        "//tensorflow:macos": ["-headerpad_max_install_names"],
        "//conditions:default": [],
    }),
    linkstatic = tf_kernel_tests_linkstatic(),
    visibility = [
        "//tensorflow:internal",
        "//tensorflow/core:__pkg__",
    ],
    # Only what variant_test.cpp includes: platform logging (lib), the gtest
    # wrappers and main (test, test_main), and Variant itself (framework).
    # The previous list pulled in XLA ops, the profiler, direct_session, etc.,
    # none of which this test uses.
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)
运行测试(tf_cc_tests 会根据源文件路径生成测试目标名:test/variant_test.cpp → test_variant_test):
bazel test --test_output all -c opt //tensorflow/learn:test_variant_test