目录
1. dataSet.h(文件名须与后文 #include <dataSet.h> 的大小写保持一致,大小写敏感的文件系统上否则会找不到头文件)
这里以图像分类数据集(昆虫分类数据集 hymenoptera_data)为例,头文件如下。数据集目录需按“根目录/类别名/图片”的方式组织,每个子文件夹对应一个类别。
(1)首先,自定义的数据集myDataset需继承torch::data::Dataset;
(2)重写基类的两个方法:get()和size();
#ifndef DATASET_H
#define DATASET_H

#include <torch/torch.h>
#include <vector>
#include <string>
#include <io.h>   // _findfirst / _findnext (Windows CRT directory-scanning API)
#include <opencv2/opencv.hpp>

// Utility: recursively scan the directory `path` and collect the image files selected by the
// file-name filter `type` (e.g. ".jpg").
//   list_images - receives the full path of every selected file
//   list_labels - receives the class label of each selected file (parallel to list_images)
//   label       - label seed; it is advanced once per sub-directory visited
void load_data_from_folder(std::string path, std::string type, std::vector<std::string> &list_images, std::vector<int> &list_labels, int label);

// Image-classification dataset for LibTorch.
// Derives from torch::data::Dataset (CRTP) and overrides get() and size(), so it can be
// mapped through transforms and fed to torch::data::make_data_loader.
class myDataset : public torch::data::Dataset<myDataset>
{
public:
    // Intended to hold the number of classes (class sub-directories) found under image_dir.
    int num_classes = 0;

    // image_dir: dataset root directory; type: file-name suffix to accept, e.g. ".jpg".
    myDataset(std::string image_dir, std::string type);

    // Override get() to return the (image, label) example at position `index`.
    torch::data::Example<> get(size_t index) override;

    // Return the number of samples in the dataset.
    torch::optional<size_t> size() const override;

private:
    std::vector<std::string> image_paths; // full path of every image
    std::vector<int> labels;              // class label of each image (same order as image_paths)
};

#endif // DATASET_H
2. dataset.cpp
重点是重写get()方法。get()每次只返回一个样本:image 的形状为 (c,h,w),label 的形状为 (1);经过 Stack 变换组成 batch 之后,batch.data 的形状才是 (b,c,h,w),batch.target 的形状为 (b,1)。
#include<dataSet.h>
// 工具函数:不用太关注此函数,
// 只需要知道功能是遍历该目录下的后缀为type的图片,返回图片路径
void load_data_from_folder(std::string path, std::string type, std::vector<std::string>& list_images, std::vector<int>& list_labels, int label)
{
long long hFile = 0; //句柄
struct _finddata_t fileInfo;
std::string pathName;
if ((hFile = _findfirst(pathName.assign(path).append("\\*.*").c_str(), &fileInfo)) == -1)
{
return;
}
do
{
const char* s = fileInfo.name;
const char* t = type.data();
if (fileInfo.attrib & _A_SUBDIR) // is sub dir
{
// traverse all files or dir of subdir
if (strcmp(s, ".") == 0 || strcmp(s, "..") == 0)
continue;
std::string sub_path = path + "\\" + fileInfo.name;
label++;
load_data_from_folder(sub_path, type, list_images, list_labels, label);
}
else // or not sub, is file
{
if (strstr(s, t)) // whether the suffix of the file is "t"
{
std::string image_path = path + "\\" + fileInfo.name;
list_images.push_back(image_path);
list_labels.push_back(label);
}
}
} while (_findnext(hFile, &fileInfo) == 0);
return;
}
// 1. Constructor: scan image_dir and record every image path together with its class label.
// Expected layout: image_dir/<class_name>/*<type>, one sub-directory per class.
//   image_dir - dataset root directory
//   type      - file-name suffix to accept, e.g. ".jpg"
myDataset::myDataset(std::string image_dir, std::string type) {
    // load_data_from_folder increments the label before descending into each sub-directory,
    // so seeding it with -1 gives the class directories labels 0, 1, ..., K-1, the range
    // [0, num_classes) expected by nll_loss / cross_entropy. (Seeding with 0 yields 1..K.)
    // Files placed directly in image_dir, before any class directory, get label -1.
    load_data_from_folder(image_dir, type, image_paths, labels, -1);
    // Derive the class count from the largest label that was assigned.
    for (const int label : labels) {
        if (label + 1 > num_classes) {
            num_classes = label + 1;
        }
    }
}
// 2. Override get(): load the image at `index` together with its class label.
// Returns:
//   data   - uint8 tensor of shape (3, 224, 224), i.e. channels x height x width, in the
//            channel order produced by cv::imread (BGR)
//   target - int64 tensor of shape {1} holding the class label
// Throws if `index` is out of range or the image file cannot be read.
torch::data::Example<> myDataset::get(size_t index) {
    const std::string& image_path = image_paths.at(index);
    cv::Mat image = cv::imread(image_path);
    // imread signals failure by returning an empty Mat rather than throwing; report the
    // offending path instead of letting cv::resize fail with an opaque assertion.
    TORCH_CHECK(!image.empty(), "failed to read image: ", image_path);
    cv::resize(image, image, cv::Size(224, 224));
    const int label = labels.at(index);
    // from_blob wraps image.data without copying; permute turns HWC into CHW.
    torch::Tensor img_tensor = torch::from_blob(image.data, { image.rows, image.cols, 3 }, torch::kByte).permute({ 2, 0, 1 });
    // Explicit int64 dtype: classification losses require Long targets, and depending on the
    // libtorch version torch::full without a dtype may produce a float tensor.
    torch::Tensor label_tensor = torch::full({ 1 }, label, torch::kLong);
    // clone() gives the image tensor its own storage: image.data is released when `image`
    // goes out of scope at the end of this function.
    return { img_tensor.clone(), label_tensor };
}
// 3. Override size(): the dataset length equals the number of image files that were found.
torch::optional<size_t> myDataset::size() const {
    const size_t sample_count = image_paths.size();
    return sample_count;
}
3. main函数
(1)创建dataset对象;
(2)创建dataloader对象;
(3)依次取出batch.data, batch.target.
#include <dataSet.h>
// Demo: build the dataset, wrap it in a DataLoader and print the shape of every batch.
// Usage: program [image_dir]   (without an argument the hard-coded default directory is used)
int main(int argc, char* argv[])
{
    int batch_size = 2;
    std::string img_dir = "F:\\zxq\\data\\hymenoptera_data\\train";
    if (argc > 1) {
        img_dir = argv[1];
    }
    // 1. Create the dataset and map it through a transform, turning it into a MapDataset.
    //    The only transform is Stack, which stacks the Examples of a batch into one data
    //    tensor and one target tensor.
    auto mdataset = myDataset(img_dir, ".jpg").map(torch::data::transforms::Stack<>());
    // A wrong directory would otherwise yield an empty dataset and a silently empty loop.
    if (mdataset.size().value_or(0) == 0) {
        std::cerr << "no .jpg images found under " << img_dir << '\n';
        return 1;
    }
    // 2. Create the DataLoader; RandomSampler visits the samples in random order.
    auto mdataloader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(std::move(mdataset), batch_size);
    // 3. Iterate over the batches and use them.
    for (auto& batch : *mdataloader) {
        const auto& data = batch.data;
        const auto& target = batch.target;
        std::cout << "data shape: " << data.sizes() << ", target shape: " << target.sizes() << '\n';
    }
    return 0;
}
参考:https://github.com/AllentDan/LibtorchTutorials/blob/main/lesson4-DatasetUtilization/main.cpp