libtorch-UNET Carvana Vehicle Contour Segmentation

libtorch-UNET Carvana vehicle contour segmentation
libtorch-VGG cifar10 classification
libtorch-FCN CamVid semantic segmentation
libtorch-RDN DIV2K super-resolution reconstruction
libtorch-char-rnn-classification (libtorch official tutorial series)
libtorch-char-rnn-generation (libtorch official tutorial series)
libtorch-char-rnn-shakespeare (libtorch official tutorial series)
libtorch-minst (libtorch official tutorial series)
Copyright notice: This is an original article by the blogger, released under the CC 4.0 BY-SA license. Please include a link to the original article and this notice when reposting.

Original article: https://blog.csdn.net/barcaFC/article/details/139865461



Preface

This post implements UNet vehicle contour segmentation on the Carvana dataset with libtorch, the PyTorch C++ API.
Dataset preparation, training, and inference are all done in libtorch C++.
Runtime environment:
Operating system: Windows, 64-bit
OpenCV 3.4.1
libtorch, any 64-bit version
Visual Studio 2017, 64-bit build
Dataset: Carvana
Reference paper: https://paperswithcode.com/paper/u-net-convolutional-networks-for-biomedical


I. What is UNET?

UNET classifies every pixel by training on and fusing symmetric high- and low-level features in a U-shaped network. The left (contracting) path and the right (expanding) path perform conventional feature extraction and inference, while skip connections between the two paths fuse the corresponding feature maps, so high-level semantic features and low-level detail features are combined effectively. This greatly improves feature extraction when samples are limited, allowing a relatively small dataset to reach good training accuracy and fast inference.
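At its core, that fusion step is a channel-wise concatenation of an upsampled decoder feature map with the corresponding encoder feature map, which a following convolution block then mixes. A minimal stand-alone illustration (the tensor shapes are arbitrary, chosen only for this example):

#include <torch/torch.h>
#include <iostream>

int main()
{
	// Skip connection from the contracting path and the upsampled map from the
	// expanding path; they must share the same spatial size before fusion.
	auto skip    = torch::randn({ 1, 64, 128, 128 });
	auto decoded = torch::randn({ 1, 64, 128, 128 });

	// U-Net's feature fusion: concatenate along dim 1 (channels), 64 + 64 -> 128 channels.
	auto fused = torch::cat({ decoded, skip }, /*dim=*/1);
	std::cout << fused.sizes() << std::endl;  // prints [1, 128, 128, 128]
}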

II. Usage steps

1. Setting the parameters

The code is as follows:

	args::ArgumentParser parser("Train the UNet on images and target masks.", "This goes after the options.");
	args::HelpFlag help(parser, "help", "Display this help menu", { 'h', "help" });

	args::ValueFlag<int> n_class(parser, "n_class", "number of segmentation classes per pixel", { "n_class" }, 1);
	args::ValueFlag<int> n_channels(parser, "n_channels", "number of input channels (3 for RGB images)", { "n_channels" }, 3);
	args::ValueFlag<bool> isbilinear(parser, "bilinear", "use bilinear upsampling instead of transposed convolutions", { "bilinear" }, true);
	args::ValueFlag<int> epochs(parser, "epochs", "number of epochs to train", { "epochs" }, 5);
	args::ValueFlag<int> batch_size(parser, "batch_size", "training batch size", { "batch_size" }, 1);
	args::ValueFlag<double> lr(parser, "lr", "learning rate", { "lr" }, 1e-4);
	args::ValueFlag<float> scale(parser, "scale", "Downscaling factor of the images", { "scale" }, 0.5);
	args::ValueFlag<std::string> str_imgs_train_dir(parser, "imgs_train_dir", "train dataset directory", { "imgs_train_dir" }, "..\\Carvana\\imgs_train");
	args::ValueFlag<std::string> str_dir_train_mask(parser, "dir_train_mask", "label dataset directory", { "dir_train_mask" }, "..\\Carvana\\masks_train");
	args::ValueFlag<std::string> str_imgs_valid_dir(parser, "imgs_valid_dir", "valid dataset directory", { "imgs_valid_dir" }, "..\\Carvana\\imgs_valid");
	args::ValueFlag<std::string> str_dir_valid_mask(parser, "dir_valid_mask", "label valid dataset directory", { "dir_valid_mask" }, "..\\Carvana\\masks_valid");

	try
	{
		parser.ParseCLI(argc, argv);
	}
	catch (args::Help)
	{
		std::cout << parser;
		return 0;
	}
	catch (const args::ParseError& e)
	{
		std::cerr << e.what() << std::endl;
		std::cerr << parser;
		return 1;
	}
	catch (const args::ValidationError& e)
	{
		std::cerr << e.what() << std::endl;
		std::cerr << parser;
		return 1;
	}

	int Epochs = (int)args::get(epochs);
	int batch = (int)args::get(batch_size);
	double learning_rate = (double)args::get(lr);
	float scale_factor = (float)args::get(scale);
	
	int nclass = (int)args::get(n_class);
	int channels = (int)args::get(n_channels);
	bool bilinear = (bool)args::get(isbilinear);
	string imgs_dir_train = (string)args::get(str_imgs_train_dir);
	string dir_mask_train = (string)args::get(str_dir_train_mask);
	string imgs_dir_valid = (string)args::get(str_imgs_valid_dir);
	string dir_mask_valid = (string)args::get(str_dir_valid_mask);
	cout << "Epochs : " << Epochs << endl;
	cout << "batch : " << batch << endl;
	cout << "learning_rate : " << learning_rate << endl;
	cout << "scale : " << scale_factor << endl;
	
	cout << "n_channels : " << channels << endl;
	cout << "n_class : " << nclass << endl;
	cout << "bilinear : " << bilinear << endl;
	cout << "imgs_train : " << imgs_dir_train << endl;
	cout << "mask_train : " << dir_mask_train << endl;
	cout << "imgs_valid : " << imgs_dir_valid << endl;
	cout << "mask_valid : " << dir_mask_valid << endl;

2. Loading the data

The code is as follows:

	std::vector<double> norm_mean = { 0.485, 0.456, 0.406 };
	std::vector<double> norm_std = { 0.229, 0.224, 0.225 };
	auto train_dataset = CarvanaDataset(imgs_dir_train, dir_mask_train, scale_factor)/*.map(torch::data::transforms::Normalize<>(norm_mean, norm_std))*/;
	auto train_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(
		std::move(train_dataset), batch);

	auto valid_dataset = CarvanaDataset(imgs_dir_valid, dir_mask_valid, scale_factor);
	auto valid_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(
		std::move(valid_dataset), batch);
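The CarvanaDataset class itself is not listed in the post; only its constructor signature (image directory, mask directory, scale factor) can be inferred from the calls above. The sketch below is one way to implement it with OpenCV; the file-extension and image/mask pairing assumptions are marked in the comments, and it is an illustration rather than the author's actual code:

#include <torch/torch.h>
#include <opencv2/opencv.hpp>
#include <string>
#include <vector>

class CarvanaDataset : public torch::data::datasets::Dataset<CarvanaDataset>
{
public:
	CarvanaDataset(const std::string& imgs_dir, const std::string& masks_dir, float scale)
		: scale_(scale)
	{
		// cv::glob returns sorted paths, so images and masks are assumed to pair up by order.
		cv::glob(imgs_dir + "\\*.jpg", img_paths_);
		// Carvana ships masks as .gif, which OpenCV cannot decode; this sketch assumes
		// they were converted to .png beforehand.
		cv::glob(masks_dir + "\\*.png", mask_paths_);
	}

	torch::data::Example<> get(size_t index) override
	{
		cv::Mat img = cv::imread(img_paths_[index], cv::IMREAD_COLOR);
		cv::Mat mask = cv::imread(mask_paths_[index], cv::IMREAD_GRAYSCALE);

		// Downscale image and mask by the same factor so they stay aligned;
		// nearest-neighbour interpolation keeps the mask binary.
		cv::resize(img, img, cv::Size(), scale_, scale_);
		cv::resize(mask, mask, cv::Size(), scale_, scale_, cv::INTER_NEAREST);

		// HWC uint8 -> CHW float in [0, 1]; OpenCV loads BGR, channel order is kept as-is here.
		auto img_t = torch::from_blob(img.data, { img.rows, img.cols, 3 }, torch::kUInt8)
			.permute({ 2, 0, 1 }).to(torch::kFloat32).div(255).clone();
		auto mask_t = torch::from_blob(mask.data, { 1, mask.rows, mask.cols }, torch::kUInt8)
			.to(torch::kFloat32).div(255).clone();

		return { img_t, mask_t };
	}

	torch::optional<size_t> size() const override
	{
		return img_paths_.size();
	}

private:
	std::vector<cv::String> img_paths_, mask_paths_;
	float scale_;
};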

3. The UNET network

The code is as follows:

struct UNET : public torch::nn::Module
{
	UNET(int n_channels = 3, int n_classes = 1, bool bilinear = true):n_channels(n_channels), n_classes(n_classes), bilinear_(bilinear)
	{
		inc = register_module("inc", std::make_shared<DoubleConv>(n_channels, 64));
		down1 = register_module("down1", std::make_shared<Down>(64, 128));
		down2 = register_module("down2", std::make_shared<Down>(128, 256));
		down3 = register_module("down3", std::make_shared<Down>(256, 512));
		if (bilinear_) factor = 2; else factor = 1;
		down4 = register_module("down4", std::make_shared<Down>(512, 1024 / factor));
		up1 = register_module("up1", std::make_shared<Up>(1024, 512 / factor, bilinear_));
		up2 = register_module("up2", std::make_shared<Up>(512, 256 / factor, bilinear_));
		up3 = register_module("up3", std::make_shared<Up>(256, 128 / factor, bilinear_));
		up4 = register_module("up4", std::make_shared<Up>(128, 64, bilinear_));
		outc = register_module("outc", std::make_shared<OutConv>(64, n_classes));
	}

	torch::Tensor forward(torch::Tensor x)
	{
		x1 = inc(x);
		x2 = down1(x1);
		x3 = down2(x2);
		x4 = down3(x3);
		x5 = down4(x4);
		x = up1(x5, x4);
		x = up2(x, x3);
		x = up3(x, x2);
		x = up4(x, x1);
		logits = outc(x);
		return logits;
	}

	torch::nn::ModuleHolder<DoubleConv>	inc{ nullptr };
	torch::nn::ModuleHolder<Down> down1{ nullptr }, down2{ nullptr }, down3{ nullptr }, down4{ nullptr };
	torch::nn::ModuleHolder<Up>	up1{ nullptr }, up2{ nullptr }, up3{ nullptr }, up4{ nullptr };
	torch::nn::ModuleHolder<OutConv> outc{ nullptr };

	torch::Tensor x1,x2,x3,x4,x5;
	torch::Tensor logits;
private:
	int n_channels;
	int n_classes;
	bool bilinear_;
	int factor;
};
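The DoubleConv, Down, Up, and OutConv blocks registered above are not listed in the post; they have to be defined before the UNET struct for it to compile. The sketch below follows the reference Python UNet ported to the libtorch C++ API; the channel split in the bilinear case and the padding of the skip tensor are taken from that reference and may differ in detail from the author's implementation:

#include <torch/torch.h>

// Two (3x3 conv -> BatchNorm -> ReLU) stages.
struct DoubleConv : torch::nn::Module
{
	DoubleConv(int in_ch, int out_ch, int mid_ch = -1)
	{
		if (mid_ch < 0) mid_ch = out_ch;
		seq = register_module("seq", torch::nn::Sequential(
			torch::nn::Conv2d(torch::nn::Conv2dOptions(in_ch, mid_ch, 3).padding(1).bias(false)),
			torch::nn::BatchNorm2d(mid_ch),
			torch::nn::ReLU(torch::nn::ReLUOptions().inplace(true)),
			torch::nn::Conv2d(torch::nn::Conv2dOptions(mid_ch, out_ch, 3).padding(1).bias(false)),
			torch::nn::BatchNorm2d(out_ch),
			torch::nn::ReLU(torch::nn::ReLUOptions().inplace(true))));
	}
	torch::Tensor forward(torch::Tensor x) { return seq->forward(x); }
	torch::nn::Sequential seq{ nullptr };
};

// 2x2 max-pooling followed by DoubleConv (one step down the contracting path).
struct Down : torch::nn::Module
{
	Down(int in_ch, int out_ch)
	{
		pool = register_module("pool", torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2)));
		conv = register_module("conv", std::make_shared<DoubleConv>(in_ch, out_ch));
	}
	torch::Tensor forward(torch::Tensor x) { return conv->forward(pool->forward(x)); }
	torch::nn::MaxPool2d pool{ nullptr };
	std::shared_ptr<DoubleConv> conv{ nullptr };
};

// Upsample (bilinear or transposed conv), pad to the skip tensor's size, concatenate, DoubleConv.
struct Up : torch::nn::Module
{
	Up(int in_ch, int out_ch, bool bilinear = true) : bilinear_(bilinear)
	{
		if (bilinear_)
		{
			conv = register_module("conv", std::make_shared<DoubleConv>(in_ch, out_ch, in_ch / 2));
		}
		else
		{
			upconv = register_module("upconv", torch::nn::ConvTranspose2d(
				torch::nn::ConvTranspose2dOptions(in_ch, in_ch / 2, 2).stride(2)));
			conv = register_module("conv", std::make_shared<DoubleConv>(in_ch, out_ch));
		}
	}
	torch::Tensor forward(torch::Tensor x1, torch::Tensor x2)
	{
		if (bilinear_)
			x1 = torch::nn::functional::interpolate(x1,
				torch::nn::functional::InterpolateFuncOptions()
					.scale_factor(std::vector<double>{ 2, 2 })
					.mode(torch::kBilinear).align_corners(true));
		else
			x1 = upconv->forward(x1);
		// Pad the upsampled map so it matches the skip tensor, then fuse along channels.
		int64_t dy = x2.size(2) - x1.size(2);
		int64_t dx = x2.size(3) - x1.size(3);
		x1 = torch::nn::functional::pad(x1,
			torch::nn::functional::PadFuncOptions({ dx / 2, dx - dx / 2, dy / 2, dy - dy / 2 }));
		return conv->forward(torch::cat({ x2, x1 }, 1));
	}
	torch::nn::ConvTranspose2d upconv{ nullptr };
	std::shared_ptr<DoubleConv> conv{ nullptr };
	bool bilinear_;
};

// Final 1x1 convolution producing one logit per class and pixel.
struct OutConv : torch::nn::Module
{
	OutConv(int in_ch, int out_ch)
	{
		conv = register_module("conv", torch::nn::Conv2d(torch::nn::Conv2dOptions(in_ch, out_ch, 1)));
	}
	torch::Tensor forward(torch::Tensor x) { return conv->forward(x); }
	torch::nn::Conv2d conv{ nullptr };
};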

Summary

The project resources have repeatedly failed to upload; the download link will be updated here later.
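Until it is available, here is a minimal sketch of how the parsed options, the data loaders, and the UNET module from the steps above could be tied together for training. The loss function (binary cross-entropy with logits, matching the default single-class output) and the Adam optimizer are illustrative assumptions, not taken from the author's project:

	torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
	auto net = std::make_shared<UNET>(channels, nclass, bilinear);
	net->to(device);

	// Optimizer and loss are illustrative choices, not from the original post.
	torch::optim::Adam optimizer(net->parameters(), torch::optim::AdamOptions(learning_rate));

	for (int epoch = 1; epoch <= Epochs; ++epoch)
	{
		net->train();
		double epoch_loss = 0.0;
		size_t batches = 0;
		for (auto& batch_data : *train_loader)
		{
			// The loader was built without a Stack<> transform, so each batch is a
			// std::vector of single examples; stack them into batched tensors here.
			std::vector<torch::Tensor> imgs, masks;
			for (auto& example : batch_data)
			{
				imgs.push_back(example.data);
				masks.push_back(example.target);
			}
			auto input = torch::stack(imgs).to(device);
			auto target = torch::stack(masks).to(device);

			optimizer.zero_grad();
			auto logits = net->forward(input);
			// Masks are expected in [0, 1]; with n_class == 1 this is plain binary segmentation.
			auto loss = torch::binary_cross_entropy_with_logits(logits, target);
			loss.backward();
			optimizer.step();

			epoch_loss += loss.item<double>();
			++batches;
		}
		cout << "epoch " << epoch << " mean train loss " << epoch_loss / batches << endl;

		torch::save(net, "unet_epoch_" + std::to_string(epoch) + ".pt");
	}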
