LibTorch:常用API总结和验证


参考 https://www.cnblogs.com/yanghailin/p/12901586.html ,对自己用到的API进行总结和验证


Installation


CMakeLists.txt

cmake_minimum_required(VERSION 3.5)

set(PROJ demo)
project(${PROJ})
set(CMAKE_CXX_STANDARD 14)

set(OpenCV_DIR "../opencv/build")
find_package(OpenCV 4 REQUIRED)

set(Torch_DIR "../libtorch/share/cmake/Torch")
find_package(Torch REQUIRED)

add_subdirectory(./some_modules)

add_executable(${PROJ} main.cpp)
target_link_libraries(${PROJ} ${OpenCV_LIBS})
target_link_libraries(${PROJ} ${TORCH_LIBRARIES})

尽量选择版本3.5左右的CMake,避免警告。


输出调试

我这里使用的是CLion调用VS进行编译,所以不支持打断点调试LibTorch,但是一般打印输出所需要的信息就能解决大部分问题。

  • 输出Tensor 变量: 适合小数据

    auto mat = torch::rand({3,4});
    cout << mat << endl;  
    
    // 输出:
    // 0.5601  0.1729  0.1988  0.2926
    // 0.8399  0.8283  0.7591  0.7151
    // 0.8795  0.6665  0.9448  0.0492
    //[ CPUFloatType{3,4} ]
    
  • 输出Tensor Sizes:

    auto var = torch::rand({1, 3, 224, 224});
    cout << var.sizes() << endl;   
    
    // 输出:
    // [1, 3, 224, 224]
    
  • 输出Tensor 切片: var[0, 0, 0:5, 0:5]

    auto var = torch::rand({1, 3, 224, 224});
    auto x = var.select(0, 0).select(0, 0).slice(0, 0, 3).slice(1, 0, 4);
    cout << x << endl;
    
    // 输出:
    // 0.5341  0.5749  0.9831  0.5344
    // 0.1251  0.3488  0.0220  0.0004
    // 0.0368  0.8825  0.9089  0.1680
    // [ CPUFloatType{3,4} ]
    
  • 输出任意Head-K个元素: 按内存取前k个元素,再转成二维矩阵的形式

    void print_head_k(torch::Tensor data, int m, int n, const string &comment = "") {
        cout << endl;
        if (!comment.empty()) cout << comment << ":" << endl;
        data = data.flatten();
        data = data.slice(0, 0, m * n).contiguous().view({m, n});
        cout << data << endl;
    }
    
    auto var = torch::rand({1, 3, 224, 224});
    print_head_k(var, 3, 4, "var");
    
    // var:
    // 0.9258  0.1272  0.8560  0.7578
    // 0.7780  0.7363  0.0554  0.8341
    // 0.5028  0.9969  0.0497  0.8425
    // [ CPUFloatType{3,4} ]
    

squeeze() 和 unsqueeze()

维度压缩 squeeze:

  • 指定dim
    auto var = torch::rand({1, 1, 224, 224, 1});
    auto a = var.squeeze(0);  // [1, 224, 224, 1]
    auto b = var.squeeze(1);  // [1, 224, 224, 1]
    auto c = var.squeeze(2);  // [1, 1, 224, 224, 1]
    auto d = var.squeeze(3);  // [1, 1, 224, 224, 1]
    auto e = var.squeeze(4);  // [1, 1, 224, 224]
    
  • 参数 dim 为空时,自动压缩第一个可压缩的维度

维度扩张 unsqueeze:

  • 指定dim

    auto var = torch::rand({4, 12800, 80});
    auto a = var.unsqueeze(0);  // [1, 4, 12800, 80]
    auto b = var.unsqueeze(1);  // [4, 1, 12800, 80]
    auto c = var.unsqueeze(2);  // [4, 12800, 1, 80]
    auto d = var.unsqueeze(3);  // [4, 12800, 80, 1]
    
  • 无默认参数


transpose()

参数:dim1, dim2 位置平等

auto var = torch::rand({1, 3, 224, 224});
auto a = var.transpose(0, 1);  // [3, 1, 224, 224]
auto b = var.transpose(1, 0);  // [3, 1, 224, 224]

nonzero()

判断Tensor元素是否为0,输出为坐标矩阵 n u m × s i z e num×size num×size

auto var = torch::rand({1, 3, 30, 30});
auto coors = var.nonzero();
cout << coors.sizes() << endl;

// 输出:
// [2700, 4]

size() 、sizes() 和 numel()

size()有参数dim,sizes() 无参数,numel 返回元素个数

auto var = torch::rand({1, 3, 30, 30});
cout << var.sizes() << endl;	// [1, 3, 30, 30]
cout << var.size(0) << endl;	//  1
cout << var.size(1) << endl;	//  3
cout << var.size(2) << endl;	// 	30
cout << var.size(3) << endl;	//	30
cout << var.numel() << endl;	//  2700

元素访问 data_ptr()

可以用指针直接修改元数据,注意类型对应。如果切片后,应该使之内存连续 .contiguous()

torch::Tensor result = torch::rand({32, 5}, torch::kFloat32);
auto p = static_cast<float *>(result[0].data_ptr());
cout << p[4] << endl;

p[4] = 0;
auto q = static_cast<float *>(result[0].data_ptr());
cout << q[4] << endl;

也可以用官方示例:

torch::Tensor foo = torch::rand({12, 12});
auto foo_a = foo.accessor<float, 2>();
for (int i = 0; i < foo_a.size(0); i++) {
     float x = foo_a[i][i];
     cout << x << endl;
 }

(To be continue)

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值