Image Classification with LibTorch


Dataset: https://download.pytorch.org/tutorial/hymenoptera_data.zip

Using fully connected layers (torch::nn::Linear) in LibTorch

Convolutional layers (a short standalone sketch of these two layers follows this list)

Implementing an MLP (multi-layer perceptron) in LibTorch

Implementing LeNet in LibTorch
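Before the full training program, here is a minimal standalone sketch (not taken from the posts listed above, and assuming only that LibTorch is installed) showing how torch::nn::Linear and torch::nn::Conv2d behave on dummy tensors:

#include <torch/torch.h>
#include <iostream>

int main()
{
	// Fully connected layer: 10 input features -> 3 output features
	torch::nn::Linear fc(torch::nn::LinearOptions(10, 3));
	auto x = torch::randn({ 4, 10 });                      // a batch of 4 samples
	std::cout << fc->forward(x).sizes() << std::endl;      // [4, 3]

	// Convolution layer: 3 input channels -> 6 output channels, 5x5 kernel
	torch::nn::Conv2d conv(torch::nn::Conv2dOptions(3, 6, 5));
	auto img = torch::randn({ 1, 3, 224, 224 });           // NCHW
	std::cout << conv->forward(img).sizes() << std::endl;  // [1, 6, 220, 220]
	return 0;
}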

Training

#include<opencv2/opencv.hpp>
#include <torch/torch.h>
#include <torch/script.h> 
#include <filesystem>
using namespace std;
namespace fs = std::filesystem;

vector<pair<string, int>> get_imgs_labels(const std::string& data_dir, map<string, int> dict_label)
{
	// 1. Label dictionary (now passed in by the caller), e.g.:
	//map<string, int> dict_label;
	//dict_label.insert(pair<string, int>("ants", 0));
	//dict_label.insert(pair<string, int>("bees", 1));
	// 2. Vector holding (image path, label) pairs
	vector<pair<string, int>> data_info;
	// 3. Read each image path and its label into data_info
	// Iterate over the label dictionary
	for (map<string, int>::iterator it = dict_label.begin(); it != dict_label.end(); it++)
	{
		// Walk the data directory looking for the folder that matches this class
		for (const auto& file_path : fs::directory_iterator(data_dir))
		{
			if (file_path.path().filename() == it->first) {
				// Collect every image path inside the class folder
				for (const auto& img_path : fs::directory_iterator(data_dir + "\\" + it->first))
				{
					//std::cout << img_path.path() << std::endl;
					data_info.push_back(pair<string, int>(img_path.path().string(), it->second));
				}
			}
			//std::cout << entry.path() << std::endl;
		}
		//printVector(data_info);
	}

	return data_info;
}
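// Optional (not part of the original code): a quick sanity check of what
// get_imgs_labels returns for the hymenoptera_data layout (train\ants, train\bees).
void print_dataset_summary()
{
	map<string, int> dict_label{ { "ants", 0 }, { "bees", 1 } };
	auto data_info = get_imgs_labels("D:\\dataset\\hymenoptera_data\\train", dict_label);
	cout << "total samples: " << data_info.size() << endl;
	if (!data_info.empty())
		cout << data_info.front().first << " -> " << data_info.front().second << endl;
}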

/// <summary>
/// Dataset class: loads image/label pairs and preprocesses them
/// </summary>
class MyDataset :public torch::data::Dataset<MyDataset> {
private:
	vector<pair<string, int>> data_info;
	torch::Tensor imgs, labels;
public:
	// Constructor: typically fixes the dataset location and the preprocessing scheme
	MyDataset(const std::string& data_dir, std::map<string, int> dict_label);
	// get(): loads one sample and applies preprocessing
	torch::data::Example<> get(size_t index) override;
	// Returns the number of samples
	torch::optional<size_t> size() const override {
		return data_info.size();
	};
};

/// <summary>
/// Pairs training images with their labels given the dataset path and a label dictionary
/// </summary>
/// <param name="data_dir"></param>
/// <param name="dict_label"></param>
MyDataset::MyDataset(const std::string& data_dir, std::map<string, int> dict_label) {
	// Collect the (image path, label) pairs
	data_info = get_imgs_labels(data_dir, dict_label);

}

/// <summary>
/// Preprocesses one sample and returns it as an Example{data, label} pair
/// </summary>
/// <param name="index"></param>
/// <returns></returns>
torch::data::Example<> MyDataset::get(size_t index)
{
	// Image path for this index
	auto img_path = data_info[index].first;
	// Corresponding label
	auto label = data_info[index].second;
	// Read the image with OpenCV
	auto image = cv::imread(img_path);
	cout << image.size() << endl;
	// Number of channels
	int channels = image.channels();
	cout << "channels:" << channels << endl;

	// Resize to the fixed network input size
	cv::resize(image, image, cv::Size(224, 224));
	// Convert cv::Mat (HWC, uint8) to a CHW float tensor and scale to [0, 1]
	auto input_tensor = torch::from_blob(image.data, { image.rows, image.cols, 3 }, torch::kByte).permute({ 2, 0, 1 }).to(torch::kFloat32) / 255.0;
	cout << input_tensor.sizes() << endl;
	
	// Convert the int label to a tensor (int64, as required by nll_loss)
	torch::Tensor label_tensor = torch::tensor(label, torch::kLong);
	return { input_tensor, label_tensor };

}
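// Optional (not part of the original code): a sketch of how per-channel
// normalization could be chained in before Stack<> when composing the dataset.
// The mean/stddev values are the common ImageNet statistics, used here as an
// assumption rather than values computed for this dataset.
auto make_train_dataset(const std::string& data_dir, std::map<string, int> dict_label)
{
	return MyDataset(data_dir, dict_label)
		.map(torch::data::transforms::Normalize<>({ 0.485, 0.456, 0.406 }, { 0.229, 0.224, 0.225 }))
		.map(torch::data::transforms::Stack<>());
}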


/// <summary>
/// LeNet implementation
/// </summary>
class LeNet :public torch::nn::Module {
public:
	// Constructor
	LeNet(int num_classes, int num_linear);
	// Forward pass
	torch::Tensor forward(torch::Tensor x);
private:
	// The layers are created and registered in the constructor
	torch::nn::Conv2d conv1{ nullptr };
	torch::nn::Conv2d conv2{ nullptr };
	torch::nn::Linear fc1{ nullptr };
	torch::nn::Linear fc2{ nullptr };
	torch::nn::Linear fc3{ nullptr };
};

LeNet::LeNet(int num_classes, int num_linear)
{
	conv1 = register_module("conv1", torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 6, 5)));
	conv2 = register_module("conv2", torch::nn::Conv2d(torch::nn::Conv2dOptions(6, 16, 5)));
	fc1 = register_module("fc1", torch::nn::Linear(torch::nn::LinearOptions(num_linear, 128)));
	fc2 = register_module("fc2", torch::nn::Linear(torch::nn::LinearOptions(128, 32)));
	fc3 = register_module("fc3", torch::nn::Linear(torch::nn::LinearOptions(32, num_classes)));
}

torch::Tensor LeNet::forward(torch::Tensor x)
{
	auto out = torch::relu(conv1->forward(x));
	out = torch::max_pool2d(out, 2);
	out = torch::relu(conv2(out));
	out = torch::max_pool2d(out, 2);
	out = out.view({ out.size(0), -1 });	// flatten, keeping the batch dimension
	out = torch::relu(fc1(out));
	out = torch::relu(fc2(out));
	out = fc3(out);
	return out;
}
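// Optional sanity check (not part of the original code): verifies where the
// num_linear value used later in main() comes from. For a 224x224 input:
// 224 -> conv1(5x5) -> 220 -> pool(2) -> 110 -> conv2(5x5) -> 106 -> pool(2) -> 53,
// so the flattened size before fc1 is 16 * 53 * 53 = 44944.
void check_lenet_shapes()
{
	auto net_check = std::make_shared<LeNet>(2, 44944);
	auto dummy = torch::randn({ 1, 3, 224, 224 });
	std::cout << net_check->forward(dummy).sizes() << std::endl;	// prints [1, 2]
}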



int main()
{
	try
	{
		map<string, int> dict_label;
		dict_label.insert(pair<string, int>("ants", 0));
		dict_label.insert(pair<string, int>("bees", 1));
		// Build the dataset
		auto dataset_train = MyDataset("D:\\dataset\\hymenoptera_data\\train", dict_label).map(torch::data::transforms::Stack<>());
		// Batch size
		int batchSize = 1;
		// Build the dataloader
		auto dataLoader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(std::move(dataset_train), batchSize);
		// Debug: inspect a few batches
		//for (auto& batch : * dataLoader) {

		//	auto data = batch.data;
		//	auto target = batch.target;
		//	std::cout << data.sizes() << std::endl;
		//	//std::cout << data.max() << std::endl;
		//	//std::cout << data << std::endl;
		//	std::cout << target << std::endl;
		//	int ssss;
		//	cin >> ssss;
		//}
		//auto net = LeNet(5, 44944);

		std::shared_ptr<LeNet> net = std::make_shared<LeNet>(2, 44944);

		
		// Optimizer
		torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01);

		
		for (size_t epoch = 1; epoch <= 10; ++epoch) {
			size_t batch_index = 0;
			// Iterate over the dataset
			for (auto& batch : *dataLoader) {
				// Zero out the gradients
				optimizer.zero_grad();
				// Forward pass
				torch::Tensor prediction = net->forward(batch.data);
				cout << "prediction:" << prediction << endl;
				cout << "target:" << batch.target << endl;

				// Compute the loss (nll_loss expects log-probabilities, so apply log_softmax to the raw logits)
				torch::Tensor loss = torch::nll_loss(torch::log_softmax(prediction, /*dim=*/1), batch.target);
				cout << "loss:" << loss << endl;
				// Backward pass
				loss.backward();
				// Update the parameters
				optimizer.step();
				// Every 20 batches: print the loss and save the model
				if (++batch_index % 20 == 0) {
					std::cout << "Epoch: " << epoch << " | Batch: " << batch_index
						<< " | Loss: " << loss << std::endl;
					// Save the model
					torch::save(net, "net.pt");

					cout << net->parameters() << endl;

				}
				
			}
		}


	}
	catch (const std::exception& e)
	{
		// Print any error that occurred
		cout << e.what() << endl;
	}
	
	return 0;
}
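Once training has produced net.pt, inference could be run along the following lines. This is a sketch rather than code from the original post: it assumes the LeNet class above is available in the same source file, "test.jpg" is only a placeholder for the image you want to classify, and the preprocessing mirrors what MyDataset::get does (resize to 224x224, scale to [0, 1]).

// A possible inference routine (assumes the LeNet class defined above;
// "test.jpg" is a placeholder path for the image to classify).
void run_inference()
{
	// Rebuild the network with the same architecture and load the trained weights
	auto net = std::make_shared<LeNet>(2, 44944);
	torch::load(net, "net.pt");
	net->eval();

	// Preprocess the image the same way as MyDataset::get
	cv::Mat image = cv::imread("test.jpg");
	cv::resize(image, image, cv::Size(224, 224));
	auto input = torch::from_blob(image.data, { image.rows, image.cols, 3 }, torch::kByte)
		.permute({ 2, 0, 1 }).to(torch::kFloat32).unsqueeze(0) / 255.0;

	// Forward pass without gradient tracking; the arg-max index is the predicted class
	torch::NoGradGuard no_grad;
	auto logits = net->forward(input);
	cout << "predicted class: " << logits.argmax(1).item<int64_t>() << endl;
}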
