TensorFlow的complex64和complex128类型实际上只是std::complex<float>和std::complex<double>的类型别名(typedef)。源码如下:
另外,加入复数类型以后,发现原来的打印函数不好用了,于是用函数模板及其特化重新实现了一遍,现在基本算是通用了。
程序结构如下:
conanfile.txt
[requires]
gtest/1.10.0
glog/0.4.0
protobuf/3.9.1
eigen/3.4.0
dataframe/1.20.0
opencv/3.4.17
boost/1.76.0
abseil/20210324.1
xtensor/0.23.10
zlib/1.2.11
[generators]
cmake
CMakeLists.txt
# Build script for the TensorFlow math-op tests: compiles a shared helper
# library from ../../include/*/impl sources and one test executable per
# *.cpp file in this directory, each linked against that library.
cmake_minimum_required(VERSION 3.3)
project(test_math_ops)
# Append /usr/local/lib/pkgconfig so pkg_search_module() below can find the
# locally installed Arrow/Parquet .pc files.
set(ENV{PKG_CONFIG_PATH} "$ENV{PKG_CONFIG_PATH}:/usr/local/lib/pkgconfig/")
set(CMAKE_CXX_STANDARD 17)
# Adds debug info to every build configuration.
# NOTE(review): add_compile_options(-g) or CMAKE_BUILD_TYPE=Debug would be the
# more conventional way to request this.
add_definitions(-g)
# Conan-generated setup (conanfile.txt uses the `cmake` generator); presumably
# this is what provides ${CONAN_LIBS} used in target_link_libraries below.
include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
conan_basic_setup()
# Provides the TensorflowCC::TensorflowCC target linked below.
find_package(TensorflowCC REQUIRED)
find_package(PkgConfig REQUIRED)
# Arrow / Parquet modules located through pkg-config; each creates an
# imported target PkgConfig::<PREFIX>.
pkg_search_module(PKG_PARQUET REQUIRED IMPORTED_TARGET parquet)
pkg_search_module(PKG_ARROW REQUIRED IMPORTED_TARGET arrow)
pkg_search_module(PKG_ARROW_COMPUTE REQUIRED IMPORTED_TARGET arrow-compute)
pkg_search_module(PKG_ARROW_CSV REQUIRED IMPORTED_TARGET arrow-csv)
pkg_search_module(PKG_ARROW_DATASET REQUIRED IMPORTED_TARGET arrow-dataset)
pkg_search_module(PKG_ARROW_FS REQUIRED IMPORTED_TARGET arrow-filesystem)
pkg_search_module(PKG_ARROW_JSON REQUIRED IMPORTED_TARGET arrow-json)
# Collected include directories of all Arrow/Parquet modules above.
set(ARROW_INCLUDE_DIRS ${PKG_PARQUET_INCLUDE_DIRS} ${PKG_ARROW_INCLUDE_DIRS} ${PKG_ARROW_COMPUTE_INCLUDE_DIRS} ${PKG_ARROW_CSV_INCLUDE_DIRS} ${PKG_ARROW_DATASET_INCLUDE_DIRS} ${PKG_ARROW_FS_INCLUDE_DIRS} ${PKG_ARROW_JSON_INCLUDE_DIRS})
# Project headers live two levels up in include/.
set(INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../include ${ARROW_INCLUDE_DIRS})
set(ARROW_LIBS PkgConfig::PKG_PARQUET PkgConfig::PKG_ARROW PkgConfig::PKG_ARROW_COMPUTE PkgConfig::PKG_ARROW_CSV PkgConfig::PKG_ARROW_DATASET PkgConfig::PKG_ARROW_FS PkgConfig::PKG_ARROW_JSON)
include_directories(${INCLUDE_DIRS})
# Every *.cpp in this directory is a standalone test program.
# NOTE(review): file(GLOB) is evaluated at configure time only; re-run cmake
# after adding a new test file.
file( GLOB test_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
# Sources of the shared helper library: four explicit TF helper .cc files
# plus all *.cpp files under the listed impl/ directories.
file( GLOB APP_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/tensor_testutil.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/queue_runner.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/coordinator.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/status.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/death_handler/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../include/df/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../include/arr_/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../include/img_util/impl/*.cpp)
add_library(${PROJECT_NAME}_lib SHARED ${APP_SOURCES})
# PUBLIC so that every test executable inherits these link dependencies.
target_link_libraries(${PROJECT_NAME}_lib PUBLIC ${CONAN_LIBS} TensorflowCC::TensorflowCC ${ARROW_LIBS})
# One executable per test source, named after the file without ".cpp".
# NOTE(review): string(REPLACE) removes every ".cpp" occurrence in the name,
# not only the suffix.
foreach( test_file ${test_file_list} )
file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${test_file})
string(REPLACE ".cpp" "" file ${filename})
add_executable(${file} ${test_file})
target_link_libraries(${file} PUBLIC ${PROJECT_NAME}_lib)
endforeach( test_file ${test_file_list})
tf_math2_test.cpp
TEST(TfArthimaticTests, Complex) {
    // ops::Complex combines a tensor of real parts and a tensor of imaginary
    // parts element-wise into a single complex tensor.
    // Refers to: https://tensorflow.google.cn/versions/r2.6/api_docs/cc/class/tensorflow/ops/complex?hl=zh-cn&authuser=0
    // NOTE(review): "Arthimatic" is a typo for "Arithmetic"; the suite name is
    // kept unchanged so existing --gtest_filter patterns still match.
    Scope root = Scope::NewRootScope();
    auto real_ = test::AsTensor<float>({2.25f, 3.25f}, {2});
    auto imag_ = test::AsTensor<float>({4.75f, 5.75f}, {2});
    // Combine the two float tensors into complex values.
    auto complex_op = ops::Complex(root, real_, imag_);
    ClientSession session(root);
    std::vector<Tensor> outputs;
    // Check the run status before touching `outputs`: if the run fails the
    // vector may be empty, and outputs[0] would be an out-of-bounds access.
    auto const status = session.Run({complex_op.out}, &outputs);
    ASSERT_TRUE(status.ok()) << status.ToString();
    ASSERT_EQ(outputs.size(), 1u);
    test::PrintTensorValue<complex64>(std::cout, outputs[0]);
    // Expected: element i is {real_[i], imag_[i]}.
    test::ExpectTensorEqual<complex64>(outputs[0], test::AsTensor<complex64>({
        {2.25f, 4.75f}, {3.25f, 5.75f}
    }, {2}));
}
PrintTensorValue函数
/// Prints every element of `tensor`, read as type T, to `os`, one element
/// per line.
///
/// While printing, the stream precision is raised to
/// std::numeric_limits<long double>::digits10 + 1 significant digits; the
/// caller's previous precision is restored before returning, so the stream's
/// formatting state is left as it was.
///
/// @tparam T     Element type; expected to match the tensor's dtype.
/// @param os     Destination stream.
/// @param tensor Tensor whose elements are printed in flat order.
/// @return `os`, for chaining.
template <typename T>
std::ostream& PrintTensorValue(std::ostream& os, Tensor const& tensor) {
    T const* tensor_pt = tensor.unaligned_flat<T>().data();
    auto size = tensor.NumElements();
    // precision(n) sets the new value and returns the old one, which we restore below.
    auto const old_precision = os.precision(std::numeric_limits<long double>::digits10 + 1);
    for (decltype(size) i = 0; i < size; ++i) {
        os << tensor_pt[i] << '\n';
    }
    os.precision(old_precision);
    return os;
}
/// uint8 specialization of the one-element-per-line printer.
///
/// uint8 is a character type, so streaming it directly would emit raw
/// characters; each element is widened to int so its numeric value is printed.
/// Precision does not affect integer output, so the stream's precision is not
/// modified.
/// NOTE(review): an explicit specialization is an ordinary function; if this
/// definition lives in a header included by several translation units it must
/// be marked `inline` to avoid duplicate-symbol errors.
///
/// @param os     Destination stream.
/// @param tensor Tensor of uint8 elements, printed in flat order.
/// @return `os`, for chaining.
template <>
std::ostream& PrintTensorValue<uint8>(std::ostream& os, Tensor const& tensor) {
    uint8 const* tensor_pt = tensor.unaligned_flat<uint8>().data();
    auto size = tensor.NumElements();
    for (decltype(size) i = 0; i < size; ++i) {
        os << static_cast<int>(tensor_pt[i]) << '\n';
    }
    return os;
}
/// Prints every element of `tensor`, read as type T, to `os`, placing
/// `per_line_count` elements on each line.
///
/// Elements on a line are separated by '\t', and '\n' follows every
/// `per_line_count`-th element. A final partial line is not terminated by a
/// newline. While printing, the stream precision is raised to
/// std::numeric_limits<long double>::digits10 + 1 significant digits; the
/// caller's previous precision is restored before returning.
///
/// @tparam T             Element type; expected to match the tensor's dtype.
/// @param os             Destination stream.
/// @param tensor         Tensor whose elements are printed in flat order.
/// @param per_line_count Elements per line; values below 1 are treated as 1.
/// @return `os`, for chaining.
template <typename T>
std::ostream& PrintTensorValue(std::ostream& os, Tensor const& tensor, int per_line_count) {
    // Guard the modulo below: zero would be undefined behavior and a negative
    // width is meaningless.
    int const width = per_line_count < 1 ? 1 : per_line_count;
    T const* tensor_pt = tensor.unaligned_flat<T>().data();
    auto size = tensor.NumElements();
    auto const old_precision = os.precision(std::numeric_limits<long double>::digits10 + 1);
    for (decltype(size) i = 0; i < size; ++i) {
        // Break after every width-th element. There is deliberately no `i != 0`
        // check: with width == 1 it would put two elements on the first line.
        bool const end_of_line = (i + 1) % width == 0;
        os << tensor_pt[i] << (end_of_line ? '\n' : '\t');
    }
    os.precision(old_precision);
    return os;
}
/// uint8 specialization of the fixed-width-per-line printer.
///
/// Each element is widened to int so its numeric value (not a raw character)
/// is printed. Elements on a line are separated by '\t', and '\n' follows
/// every `per_line_count`-th element; a final partial line is not terminated
/// by a newline. Precision does not affect integer output, so the stream's
/// precision is not modified.
/// NOTE(review): if this explicit specialization is defined in a header
/// included by several translation units it must be marked `inline`.
///
/// @param os             Destination stream.
/// @param tensor         Tensor of uint8 elements, printed in flat order.
/// @param per_line_count Elements per line; values below 1 are treated as 1.
/// @return `os`, for chaining.
template <>
std::ostream& PrintTensorValue<uint8>(std::ostream& os, Tensor const& tensor, int per_line_count) {
    // Guard the modulo below against zero or negative widths.
    int const width = per_line_count < 1 ? 1 : per_line_count;
    uint8 const* tensor_pt = tensor.unaligned_flat<uint8>().data();
    auto size = tensor.NumElements();
    for (decltype(size) i = 0; i < size; ++i) {
        // No `i != 0` check: with width == 1 it would put two elements on the first line.
        bool const end_of_line = (i + 1) % width == 0;
        os << static_cast<int>(tensor_pt[i]) << (end_of_line ? '\n' : '\t');
    }
    return os;
}
程序输出如下,