TensorFlow C++ API:生成复数算子

TensorFlow 的 complex64 和 complex128 类型实际上只是对 std::complex 的简单类型重定义(typedef),分别对应 std::complex<float> 和 std::complex<double>。源码如下:
(图:complex64 和 complex128 的类型定义)

另外,加入复数类型后发现原来的打印函数不再适用,于是用函数模板及其特化重新实现了一遍,现在基本可以通用了。
程序结构如下:
(图:程序代码结构)
conanfile.txt

 [requires]
 gtest/1.10.0
 glog/0.4.0
 protobuf/3.9.1
 eigen/3.4.0
 dataframe/1.20.0
 opencv/3.4.17
 boost/1.76.0
 abseil/20210324.1
 xtensor/0.23.10
 zlib/1.2.11

 [generators]
 cmake

CMakeLists.txt

# Build script for the TensorFlow C++ math-op tests: every *.cpp file in this
# directory becomes its own executable, and all of them link against one shared
# helper library built from sources under ../../include.
# NOTE(review): pkg_search_module(... IMPORTED_TARGET ...) below is documented
# as requiring CMake >= 3.6; consider raising this minimum (review the policy
# changes between 3.3 and 3.6, e.g. CMP0065, before doing so).
cmake_minimum_required(VERSION 3.3)


project(test_math_ops)

# Append /usr/local/lib/pkgconfig to the pkg-config search path, presumably
# because the Arrow/Parquet .pc files are installed there — TODO confirm.
set(ENV{PKG_CONFIG_PATH} "$ENV{PKG_CONFIG_PATH}:/usr/local/lib/pkgconfig/")

set(CMAKE_CXX_STANDARD 17)
# Adds -g (debug info) to the compile flags.
# NOTE(review): add_definitions() is intended for preprocessor definitions;
# add_compile_options(-g) or CMAKE_BUILD_TYPE=Debug would express this more directly.
add_definitions(-g)

# Conan's "cmake" generator (see conanfile.txt) writes conanbuildinfo.cmake into
# the build directory; conan_basic_setup() wires in the Conan dependencies
# (CONAN_LIBS is used further down).
include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
conan_basic_setup()

# Provides the TensorflowCC::TensorflowCC target linked below.
find_package(TensorflowCC REQUIRED)
# Arrow/Parquet are found through pkg-config; IMPORTED_TARGET exposes each
# module as a PkgConfig::<PREFIX> target.
find_package(PkgConfig REQUIRED)
pkg_search_module(PKG_PARQUET REQUIRED IMPORTED_TARGET parquet)
pkg_search_module(PKG_ARROW REQUIRED IMPORTED_TARGET arrow)
pkg_search_module(PKG_ARROW_COMPUTE REQUIRED IMPORTED_TARGET arrow-compute)
pkg_search_module(PKG_ARROW_CSV REQUIRED IMPORTED_TARGET arrow-csv)
pkg_search_module(PKG_ARROW_DATASET REQUIRED IMPORTED_TARGET arrow-dataset)
pkg_search_module(PKG_ARROW_FS REQUIRED IMPORTED_TARGET arrow-filesystem)
pkg_search_module(PKG_ARROW_JSON REQUIRED IMPORTED_TARGET arrow-json)

# Collect the include directories reported by every Arrow/Parquet module.
set(ARROW_INCLUDE_DIRS ${PKG_PARQUET_INCLUDE_DIRS} ${PKG_ARROW_INCLUDE_DIRS} ${PKG_ARROW_COMPUTE_INCLUDE_DIRS} ${PKG_ARROW_CSV_INCLUDE_DIRS} ${PKG_ARROW_DATASET_INCLUDE_DIRS} ${PKG_ARROW_FS_INCLUDE_DIRS} ${PKG_ARROW_JSON_INCLUDE_DIRS})

# Project headers live two directories up, in include/.
set(INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../include ${ARROW_INCLUDE_DIRS})

# Imported pkg-config targets for all Arrow/Parquet modules.
set(ARROW_LIBS PkgConfig::PKG_PARQUET PkgConfig::PKG_ARROW PkgConfig::PKG_ARROW_COMPUTE PkgConfig::PKG_ARROW_CSV PkgConfig::PKG_ARROW_DATASET PkgConfig::PKG_ARROW_FS PkgConfig::PKG_ARROW_JSON)

include_directories(${INCLUDE_DIRS})


# Every .cpp in this directory is treated as a standalone test program.
file( GLOB test_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) 

# Shared helper sources: selected TensorFlow utilities (tensor_testutil,
# queue_runner, coordinator, status) plus all implementation files of the
# death_handler, df, arr_ and img_util helpers.
file( GLOB APP_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/tensor_testutil.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/queue_runner.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/coordinator.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/tf_/impl/status.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../include/death_handler/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../include/df/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../include/arr_/impl/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../include/img_util/impl/*.cpp)

# Build the helpers once as a shared library; PUBLIC propagates the Conan,
# TensorFlow and Arrow link dependencies to every test executable.
add_library(${PROJECT_NAME}_lib SHARED ${APP_SOURCES})
target_link_libraries(${PROJECT_NAME}_lib PUBLIC ${CONAN_LIBS} TensorflowCC::TensorflowCC ${ARROW_LIBS})

# One executable per test source, named after the file with ".cpp" removed
# (string(REPLACE) removes every ".cpp" occurrence in the file name).
foreach( test_file ${test_file_list} )
    file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${test_file})
    string(REPLACE ".cpp" "" file ${filename})
    add_executable(${file}  ${test_file})
    target_link_libraries(${file} PUBLIC ${PROJECT_NAME}_lib)
endforeach( test_file ${test_file_list})

tf_math2_test.cpp

TEST(TfArthimaticTests, Complex) {
    // ops::Complex combines a tensor of real parts and a tensor of imaginary
    // parts element-wise into one complex tensor (real_[i] + image_[i] * j).
    // Refers to: https://tensorflow.google.cn/versions/r2.6/api_docs/cc/class/tensorflow/ops/complex?hl=zh-cn&authuser=0
    Scope root = Scope::NewRootScope();
    auto real_ = test::AsTensor<float>({2.25f, 3.25f}, {2});
    auto image_ = test::AsTensor<float>({4.75f, 5.75f}, {2});

    // Build the complex op from the two float tensors.
    auto complex_op = ops::Complex(root, real_, image_);

    ClientSession session(root);
    std::vector<Tensor> outputs;
    // Check the run status before reading `outputs`: if the run fails there is
    // no guarantee the vector was populated, and outputs[0] would be invalid.
    const auto run_status = session.Run({complex_op.out}, &outputs);
    ASSERT_TRUE(run_status.ok()) << run_status.ToString();
    ASSERT_EQ(outputs.size(), 1u);

    test::PrintTensorValue<complex64>(std::cout, outputs[0]);
    // All inputs are exactly representable in float, so an exact comparison
    // of the complex results is safe here.
    test::ExpectTensorEqual<complex64>(outputs[0], test::AsTensor<complex64>({
        {2.25f, 4.75f}, {3.25f, 5.75f}
    }, {2}));
}

PrintTensorValue函数

// Prints every element of `tensor`, one value per line, reading the tensor's
// buffer as elements of type T.
//
// @tparam T      element type to read the tensor as (e.g. float, complex64).
// @param os      destination stream.
// @param tensor  tensor whose elements are printed in flat order.
// @return os, to allow chaining.
//
// Floating-point values are written with std::numeric_limits<long double>::
// digits10 + 1 significant digits. The stream's previous precision is restored
// before returning, so the caller's stream settings are not permanently changed.
template <typename T>
std::ostream& PrintTensorValue(std::ostream& os, Tensor const& tensor) {
    T const* tensor_pt = tensor.unaligned_flat<T>().data();
    auto size = tensor.NumElements();
    // precision(n) installs the new value and returns the old one.
    const auto saved_precision =
        os.precision(std::numeric_limits<long double>::digits10 + 1);
    for (decltype(size) i = 0; i < size; ++i) {
        os << tensor_pt[i] << "\n";
    }
    os.precision(saved_precision);
    return os;
}

// uint8 specialization of the one-value-per-line printer.
// Values are widened to int so they print as numbers instead of as raw
// characters. Stream precision only affects floating-point output, so it is
// not modified here (this avoids leaving a changed precision on the caller's
// stream).
// NOTE(review): if this explicit specialization is defined in a header that is
// included by several translation units, it needs `inline` to avoid
// duplicate-symbol link errors — confirm where it lives.
template <>
std::ostream& PrintTensorValue<uint8>(std::ostream& os, Tensor const& tensor) {
    uint8 const* tensor_pt = tensor.unaligned_flat<uint8>().data();
    auto size = tensor.NumElements();
    for (decltype(size) i = 0; i < size; ++i) {
        os << static_cast<int>(tensor_pt[i]) << "\n";
    }
    return os;
}

// Prints the elements of `tensor` as a grid: values on a row are separated by
// '\t', and a '\n' follows every `per_line_count`-th value.
//
// @tparam T              element type to read the tensor as.
// @param os              destination stream.
// @param tensor          tensor whose elements are printed in flat order.
// @param per_line_count  number of values per row; values below 1 are treated
//                        as 1 (one value per line) instead of dividing by zero.
// @return os, to allow chaining.
//
// If the element count is not a multiple of per_line_count, the final partial
// row ends with '\t' and no newline. Floating-point values use
// std::numeric_limits<long double>::digits10 + 1 significant digits, and the
// stream's previous precision is restored before returning.
template <typename T>
std::ostream& PrintTensorValue(std::ostream& os, Tensor const& tensor, int per_line_count) {
    // Guard the modulo below: a zero divisor is undefined behavior.
    if (per_line_count < 1) {
        per_line_count = 1;
    }
    T const* tensor_pt = tensor.unaligned_flat<T>().data();
    auto size = tensor.NumElements();
    const auto saved_precision =
        os.precision(std::numeric_limits<long double>::digits10 + 1);
    for (decltype(size) i = 0; i < size; ++i) {
        // Element i is the (i + 1)-th value, so it closes a row exactly when
        // i + 1 is a multiple of per_line_count. (The previous `i != 0` guard
        // wrongly put two values on the first row when per_line_count == 1.)
        const bool ends_row = (i + 1) % per_line_count == 0;
        os << tensor_pt[i] << (ends_row ? "\n" : "\t");
    }
    os.precision(saved_precision);
    return os;
}


// uint8 specialization of the grid printer: values on a row are separated by
// '\t', and a '\n' follows every `per_line_count`-th value. Values are widened
// to int so they print as numbers rather than raw characters. Stream precision
// only affects floating-point output, so it is not modified here.
//
// @param per_line_count  values per row; values below 1 are treated as 1
//                        instead of dividing by zero.
// If the element count is not a multiple of per_line_count, the final partial
// row ends with '\t' and no newline.
// NOTE(review): if this explicit specialization is defined in a header that is
// included by several translation units, it needs `inline` to avoid
// duplicate-symbol link errors — confirm where it lives.
template <>
std::ostream& PrintTensorValue<uint8>(std::ostream& os, Tensor const& tensor, int per_line_count) {
    // Guard the modulo below: a zero divisor is undefined behavior.
    if (per_line_count < 1) {
        per_line_count = 1;
    }
    uint8 const* tensor_pt = tensor.unaligned_flat<uint8>().data();
    auto size = tensor.NumElements();
    for (decltype(size) i = 0; i < size; ++i) {
        // Element i closes a row exactly when i + 1 is a multiple of
        // per_line_count (no special case for i == 0).
        const bool ends_row = (i + 1) % per_line_count == 0;
        os << static_cast<int>(tensor_pt[i]) << (ends_row ? "\n" : "\t");
    }
    return os;
}

程序输出如下:
(图:复数 API 的输出)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值