libtorch C++ 创建自己的Dataset

目录

1. dataset.h

2. dataset.cpp

3. main函数


1. dataset.h

这里以图像分类数据集(昆虫分类数据集)为例,头文件如下。

(1)首先,自定义的数据集myDataset需继承torch::data::Dataset;

(2)重写基类的两个方法:get()和size();

#ifndef DATASET_H
#define DATASET_H
#endif // DATASET_H

#include<torch/torch.h>
#include<vector>
#include<string>
#include <io.h>
#include<opencv2/opencv.hpp>

// 工具函数:遍历该目录下的后缀为type的图片
void load_data_from_folder(std::string path, std::string type, std::vector<std::string> &list_images, std::vector<int> &list_labels, int label);

class myDataset:public torch::data::Dataset<myDataset>
{
public:
    int num_classes = 0;
    myDataset(std::string image_dir, std::string type);
    // Override get() function to return tensor at location index
    torch::data::Example<> get(size_t index) override;
    // Return the length of data
    torch::optional<size_t> size() const override;
private:
    std::vector<std::string> image_paths;  // 所有图片的完整路径
    std::vector<int> labels;  // 所有标注图像的对应类别。
};

2. dataset.cpp

重点是重写get()方法,返回的shape:image(b,c,h,w) , label(b,1).

#include<dataSet.h>

// 工具函数:不用太关注此函数,
// 只需要知道功能是遍历该目录下的后缀为type的图片,返回图片路径
void load_data_from_folder(std::string path, std::string type, std::vector<std::string>& list_images, std::vector<int>& list_labels, int label)
{
    long long hFile = 0; //句柄
    struct _finddata_t fileInfo;
    std::string pathName;
    if ((hFile = _findfirst(pathName.assign(path).append("\\*.*").c_str(), &fileInfo)) == -1)
    {
        return;
    }
    do
    {
        const char* s = fileInfo.name;
        const char* t = type.data();
        if (fileInfo.attrib & _A_SUBDIR)  // is sub dir
        {
            // traverse all files or dir of subdir
            if (strcmp(s, ".") == 0 || strcmp(s, "..") == 0) 
                continue;
            std::string sub_path = path + "\\" + fileInfo.name;
            label++;
            load_data_from_folder(sub_path, type, list_images, list_labels, label);
        }
        else // or not sub, is file
            {
            if (strstr(s, t))  // whether the suffix of the file is "t"
            {
                std::string image_path = path + "\\" + fileInfo.name;
                list_images.push_back(image_path);
                list_labels.push_back(label);
            }
        }
    } while (_findnext(hFile, &fileInfo) == 0);

    return;
}

// 1.构造函数,初始化myDataset对象
// 遍历路径下所有图片,获取路径和对应的类别
myDataset::myDataset(std::string image_dir, std::string type) {  // 构造函数,初始化myDataset对象,参数:路径,文件名后缀
    load_data_from_folder(image_dir, std::string(type), image_paths, labels, num_classes);
}
// 2. 重写get方法
// 获取一张图片和对应的类别
torch::data::Example<> myDataset::get(size_t index){
    std::string image_path = image_paths.at(index);
    cv::Mat image = cv::imread(image_path);
    cv::resize(image, image, cv::Size(224, 224));
    int label = labels.at(index);
    torch::Tensor img_tensor = torch::from_blob(image.data, { image.rows, image.cols, 3 }, torch::kByte).permute({ 2, 0, 1 }); // Channels x Height x Width
    torch::Tensor label_tensor = torch::full({ 1 }, label);  // sizes: {1}
    return { img_tensor.clone(), label_tensor.clone() };  // sizes: 1chw, {1}
}
// 3. 重写size方法
torch::optional<size_t> myDataset::size() const{
    return image_paths.size();
};

3. main函数

(1)创建dataset对象;

(2)创建dataloader对象;

(3)依次取出batch.data, batch.target. 

#include <dataSet.h>

int main(int argc, char* argv[])
{
    int batch_size = 2;
    std::string img_dir = "F:\\zxq\\data\\hymenoptera_data\\train";
    // 1. 创建dataset对象。
    // 将dataset对象转成MapDataset,如此可以将指定的transform应用到该数据集。
    // 当前只有一个transform: stack,statck all tensor to one tensor.
    auto mdataset = myDataset(img_dir, ".jpg").map(torch::data::transforms::Stack<>());  
    // 2. 创建dataloader对象
    auto mdataloader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(std::move(mdataset), batch_size);
    // 3. 遍历,取出数据,进行使用
    for (auto& batch : *mdataloader) {
        auto data = batch.data;
        auto target = batch.target;
        std::cout << "data shape: " << data.sizes() << "target shape: " << target.sizes() << std::endl;
    }

}

参考:https://github.com/AllentDan/LibtorchTutorials/blob/main/lesson4-DatasetUtilization/main.cpp

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Mr.Q

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值