C++调用pytorch的ResNet模型提取图片特征向量

ResNet模型的实现大多依赖于PyTorch框架,本文讲述了如何通过C++调用TorchScript模型,实现用C++进行ResNet特征提取,并生成特征向量库进行保存。

功能概述

本代码实现了一个基于 LibTorch 和 OpenCV 的图像特征提取工具,能够:

  1. 批量处理图像:遍历指定目录下的所有PNG图像(支持嵌套子目录)。

  2. 自动设备选择:优先使用GPU加速,若无GPU则回退到CPU。

  3. 图像标准化处理:将图像缩放到固定大小(默认224x224),并转换为PyTorch张量。

  4. 特征提取与保存:使用预训练的TorchScript模型提取特征,并将结果保存为.pt文件。


输出文件格式

生成的.pt文件包含两个张量:

  • features: 形状为 [N, D] 的特征矩阵(N为图像数,D为特征维度)

  • labels: 形状为 [N] 的类别索引


环境与资源依赖
  • LibTorch:CPU版可下载这个文件 libtorch-win-shared-with-deps-2.6.0+cpu.zip

  • OpenCV:用于图像读取与预处理(我这里用的是4.0.0版本)

  • C++17编译器:支持std::filesystem的编译器(如GCC 9+/Clang 10+/MSVC 2019+)

  • TorchScript模型:去掉了最后的全连接层直接输出向量,可下载 resnet18.pt

  • names文件:每个类别名称单独占一行的txt文件即可


具体代码实现
#include <torch/torch.h>
#include <torch/script.h>
#include <opencv2/opencv.hpp>
#include <iostream>
#include <filesystem>
#include <chrono>
#include <vector>
#include <fstream>
#include <unordered_map>

namespace fs = std::filesystem;

/**
 * @brief 提取图像特征并保存到文件
 * 
 * @param model_path   TorchScript模型路径
 * @param root_dir     包含类别子文件夹的图像根目录
 * @param output_file  输出特征文件的路径(.pt格式)
 * @param names_file   类别名称与索引的映射文件(每行一个类别名)
 * @param image_size   目标图像大小(默认224x224,适用于ResNet等模型)
 */
void getfeaturelib(const std::string& model_path, 
                   const std::string& root_dir, 
                   const std::string& output_file, 
                   const std::string& names_file,
                   const int image_size = 224) {
    // -------------------- 1. 加载模型 --------------------
    torch::jit::script::Module model = torch::jit::load(model_path);
    model.eval();  // 设置为推理模式

    // 自动选择设备(优先使用GPU)
    torch::Device device(torch::kCPU);
    if (torch::cuda::is_available()) {
        device = torch::Device(torch::kCUDA);
        std::cout << "Using GPU acceleration." << std::endl;
    }
    model.to(device);  // 将模型移至设备

    // ----------- 2. 读取类别名称与索引的映射文件 -----------
    std::unordered_map<std::string, int64_t> class_name_to_index;
    std::ifstream names_input(names_file);
    std::string line;
    int64_t index = 0;
    while (std::getline(names_input, line)) {
        class_name_to_index[line] = index++;
    }
    names_input.close();

    // ----------- 3. 遍历图像并提取特征 -----------
    std::vector<torch::Tensor> features;
    std::vector<int64_t> class_indexes;

    // 统计总图像数量(用于进度显示)
    int total_images = 0;
    for (const auto& entry : fs::recursive_directory_iterator(root_dir)) {
        if (entry.path().extension() == ".png") total_images++;
    }

    auto start = std::chrono::high_resolution_clock::now();  // 开始计时
    int processed_images = 0;

    for (const auto& entry : fs::recursive_directory_iterator(root_dir)) {
        if (entry.path().extension() == ".png") {
            const std::string image_path = entry.path().string();
            const std::string class_name = entry.path().parent_path().filename().string();

            // 3.1 读取图像
            cv::Mat image = cv::imread(image_path);
            if (image.empty()) {
                std::cerr << "Warning: Skipping invalid image: " << image_path << std::endl;
                continue;
            }

            // 3.2 将图像缩放到固定大小(默认224x224)
            cv::resize(image, image, cv::Size(image_size, image_size));

            // 3.3 将OpenCV图像转换为PyTorch张量
            torch::Tensor input_tensor = torch::from_blob(
                image.data, 
                {1, image_size, image_size, 3},  // 形状: [Batch=1, Height=224, Width=224, Channels=3]
                torch::kByte
            )
            .permute({0, 3, 1, 2})  // 转换为NCHW格式: [1, 3, 224, 224]
            .to(device)              // 移至GPU/CPU
            .to(torch::kFloat)       // 转换为浮点数
            / 255.0;                 // 归一化到[0,1]

            // 3.4 提取特征
            torch::NoGradGuard no_grad;  // 禁用梯度计算
            torch::Tensor feature = model.forward({input_tensor}).toTensor()
                                      .cpu()       // 移回CPU
                                      .flatten();   // 展平为1D向量

            // 3.5 保存结果
            features.push_back(feature);
            class_indexes.push_back(class_name_to_index[class_name]);

            // 显示进度
            processed_images++;
            std::cout << "Processed: " << processed_images << "/" << total_images 
                      << " (" << (processed_images * 100 / total_images) << "%) \r" << std::flush;
        }
    }

    // -------------------- 4. 保存结果 --------------------
    auto end = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> total_time = end - start;

    // 合并特征和标签
    torch::save(
        {torch::stack(features), torch::tensor(class_indexes)}, 
        output_file
    );

    // 输出统计信息
    std::cout << "\n\n=============== Summary ===============" << std::endl;
    std::cout << "Output saved to: " << output_file << std::endl;
    std::cout << "Total time: " << total_time.count() << " seconds" << std::endl;
    std::cout << "Speed: " << (total_images / total_time.count()) << " images/s" << std::endl;
    std::cout << "Device: " << (device.is_cuda() ? "GPU" : "CPU") << std::endl;
}

// 主函数:解析命令行参数并调用特征提取函数
int main(int argc, char** argv) {
    if (argc != 5) {
        std::cerr << "Usage: " << argv[0] 
                  << " <model.pt> <image_dir> <output.pt> <names.txt>" << std::endl;
        return 1;
    }

    // 调用特征提取函数,默认图像大小为224x224
    getfeaturelib(argv[1], argv[2], argv[3], argv[4], 224);
    return 0;
}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值