LibTorch之图像分类
LibTorch之图像分类
LibTorch之图像分类
训练
#include <algorithm>
#include <filesystem>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include<opencv2/opencv.hpp>
#include <torch/torch.h>
#include <torch/script.h>
using namespace std;
namespace fs = std::filesystem;
/// Collects (image path, class id) pairs from one sub-directory per class.
///
/// For every entry `name -> id` in `dict_label`, the directory `data_dir/name`
/// is scanned non-recursively and each regular file found in it is recorded
/// with label `id`. Labels that have no matching directory are skipped.
///
/// @param data_dir   Root directory holding one sub-directory per class.
/// @param dict_label Mapping from sub-directory name to integer class id.
/// @return Samples grouped in `dict_label` key order; paths are sorted within
///         each class so the result does not depend on directory enumeration order.
/// @throws std::filesystem::filesystem_error if a class directory cannot be read.
std::vector<std::pair<std::string, int>> get_imgs_labels(const std::string& data_dir, std::map<std::string, int> dict_label)
{
    std::vector<std::pair<std::string, int>> data_info;
    const std::filesystem::path root(data_dir);

    for (const auto& [label_name, label_id] : dict_label)
    {
        // operator/ inserts the platform's preferred separator, so this works
        // on both Windows and POSIX (a literal "\\" only works on Windows).
        const std::filesystem::path class_dir = root / label_name;
        if (!std::filesystem::is_directory(class_dir))
        {
            continue;
        }

        std::vector<std::string> class_files;
        for (const auto& entry : std::filesystem::directory_iterator(class_dir))
        {
            // Nested directories and other non-file entries are not images.
            if (entry.is_regular_file())
            {
                class_files.push_back(entry.path().string());
            }
        }

        // directory_iterator order is unspecified; sort for reproducible runs.
        std::sort(class_files.begin(), class_files.end());
        for (auto& file : class_files)
        {
            data_info.emplace_back(std::move(file), label_id);
        }
    }
    return data_info;
}
/// Dataset of (image path, class id) samples exposed through the LibTorch
/// `torch::data::Dataset` interface.
class MyDataset : public torch::data::Dataset<MyDataset> {
public:
    /// @param data_dir   Root directory of the images.
    /// @param dict_label Mapping from class name to integer class id.
    MyDataset(const std::string& data_dir, std::map<string, int> dict_label);

    /// Returns the example stored at position `index` (defined out of line).
    torch::data::Example<> get(size_t index) override;

    /// Number of samples, i.e. the number of entries in `data_info`.
    torch::optional<size_t> size() const override
    {
        return data_info.size();
    }

private:
    // One (image path, class id) pair per sample.
    vector<pair<string, int>> data_info;
    // Declared but not referenced by any member function visible in this file.
    torch::Tensor imgs;
    torch::Tensor labels;
};
/// Builds the sample list by passing `data_dir` and `dict_label` to
/// get_imgs_labels() and storing the returned (path, label) pairs.
MyDataset::MyDataset(const std::string& data_dir, std::map<string, int> dict_label)
    : data_info(get_imgs_labels(data_dir, dict_label))
{
}
/// Loads, resizes and normalises the image of sample `index`.
///
/// The image is read with OpenCV, resized to 224x224, converted from an
/// interleaved 8-bit HxWx3 buffer to a float32 3xHxW tensor and scaled to
/// [0, 1]. Diagnostic information (original size, channel count, tensor shape)
/// is printed to stdout for every sample.
///
/// @param index Position of the sample in `data_info`.
/// @return Example whose data is the 3x224x224 float tensor and whose target
///         is a 0-dimensional int64 tensor holding the class id.
/// @throws c10::Error if the image file cannot be read or decoded.
torch::data::Example<> MyDataset::get(size_t index)
{
    // Bind by reference: avoids copying the path string on every sample.
    const auto& [img_path, label] = data_info[index];

    cv::Mat image = cv::imread(img_path);
    // imread signals failure with an empty Mat rather than an exception;
    // fail here with the offending path instead of deep inside cv::resize.
    TORCH_CHECK(!image.empty(), "MyDataset::get: failed to read image '", img_path, "'");

    cout << image.size() << endl;
    cout << "channels:" << image.channels() << endl;

    cv::resize(image, image, cv::Size(224, 224));

    // from_blob only wraps image.data without owning it; the dtype conversion
    // in .to(kFloat32) produces an independent copy, so the returned tensor
    // stays valid after `image` is destroyed.
    // 8-bit pixel values span 0..255, hence the divisor 255 (not 225).
    auto input_tensor = torch::from_blob(image.data, { image.rows, image.cols, 3 }, torch::kByte)
                            .permute({ 2, 0, 1 })
                            .to(torch::kFloat32) / 255.0;
    cout << input_tensor.sizes() << endl;

    // Classification losses such as nll_loss expect int64 class indices.
    torch::Tensor label_tensor = torch::tensor(label, torch::kInt64);
    return { input_tensor, label_tensor };
}
/// Small convolutional network: two convolution layers followed by three
/// fully connected layers.
class LeNet : public torch::nn::Module {
public:
    /// @param num_classes Output width of the last fully connected layer.
    /// @param num_linear  Input width of the first fully connected layer.
    LeNet(int num_classes, int num_linear);

    /// Runs the network on `x` and returns the output of the last layer.
    torch::Tensor forward(torch::Tensor x);

private:
    // Every holder starts out empty (nullptr); the constructor replaces each
    // placeholder with a real module.
    torch::nn::Conv2d conv1{nullptr};
    torch::nn::Conv2d conv2{nullptr};
    torch::nn::Linear fc1{nullptr};
    torch::nn::Linear fc2{nullptr};
    torch::nn::Linear fc3{nullptr};
};
/// Creates the five layers and registers them as sub-modules.
///
/// Layers: conv1 (3 -> 6 channels, 5x5 kernel), conv2 (6 -> 16 channels,
/// 5x5 kernel), fc1 (num_linear -> 128), fc2 (128 -> 32),
/// fc3 (32 -> num_classes).
LeNet::LeNet(int num_classes, int num_linear)
    : conv1(torch::nn::Conv2dOptions(3, 6, 5)),
      conv2(torch::nn::Conv2dOptions(6, 16, 5)),
      fc1(torch::nn::LinearOptions(num_linear, 128)),
      fc2(torch::nn::LinearOptions(128, 32)),
      fc3(torch::nn::LinearOptions(32, num_classes))
{
    // Registration order is kept as conv1, conv2, fc1, fc2, fc3.
    register_module("conv1", conv1);
    register_module("conv2", conv2);
    register_module("fc1", fc1);
    register_module("fc2", fc2);
    register_module("fc3", fc3);
}
/// Forward pass: (conv -> ReLU -> 2x2 max-pool) twice, flatten, then three
/// fully connected layers with ReLU between them.
///
/// @param x Input batch; the first dimension is treated as the batch size.
///          The flattened size per sample after the second pooling stage must
///          equal the `num_linear` the network was constructed with.
/// @return The raw output of `fc3`, one row per input sample. No softmax or
///         log-softmax is applied here.
torch::Tensor LeNet::forward(torch::Tensor x)
{
    auto out = torch::relu(conv1->forward(x));
    out = torch::max_pool2d(out, 2);
    out = torch::relu(conv2->forward(out));
    out = torch::max_pool2d(out, 2);
    // Keep the batch dimension and flatten everything else. A hard-coded
    // leading 1 would merge all samples of a batch into a single row.
    out = out.view({ out.size(0), -1 });
    out = torch::relu(fc1->forward(out));
    out = torch::relu(fc2->forward(out));
    return fc3->forward(out);
}
int main()
{
try
{
map<string, int> dict_label;
dict_label.insert(pair<string, int>("ants", 0));
dict_label.insert(pair<string, int>("bees", 1));
auto dataset_train = MyDataset("D:\\dataset\\hymenoptera_data\\train", dict_label).map(torch::data::transforms::Stack<>());
int batchSize = 1;
auto dataLoader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(std::move(dataset_train), batchSize);
std::shared_ptr<LeNet> net = std::make_shared<LeNet>(2, 44944);
torch::optim::SGD optimizer(net->parameters(), 0.01);
for (size_t epoch = 1; epoch <= 10; ++epoch) {
size_t batch_index = 0;
for (auto& batch : *dataLoader) {
optimizer.zero_grad();
torch::Tensor prediction = net->forward(batch.data);
cout << "prediction:" << prediction << endl;
cout << "target:" << batch.target << endl;
torch::Tensor loss = torch::nll_loss(prediction, batch.target);
cout <<"loss:" << loss << endl;
loss.backward();
optimizer.step();
if (++batch_index % 20 == 0) {
std::cout << "Epoch: " << epoch << " | Batch: " << batch_index
<< " | Loss: " << loss << std::endl;
torch::save(net, "net.pt");
cout << net->parameters() << endl;
}
}
}
}
catch (const std::exception& e)
{
cout << e.what() << endl;
}
return 0;
}