libtorch-FCN CamVid语义分割

libtorch-FCN CamVid语义分割

libtorch-UNET Carvana车辆轮廓分割
libtorch-VGG cifar10分类
libtorch-FCN CamVid语义分割
libtorch-RDN DIV2K超分重建
libtorch-char-rnn-classification libtorch官网系列教程
libtorch-char-rnn-generation libtorch官网系列教程
libtorch-char-rnn-shakespeare libtorch官网系列教程
libtorch-minst libtorch官网系列教程
————————————————

                        版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。

原文链接:https://blog.csdn.net/barcaFC/article/details/140002947



前言

libtorch C++ pytorch fcn
从数据集准备、训练、推理 全部由libtorch c++完成
运行环境:
操作系统:windows 64位
opencv3.4.1
libtorch 任意版本 64位
visual studio 2017 64位编译
数据集:CamVid
参考论文:https://paperswithcode.com/paper/fully-convolutional-networks-for-semantic
python工程:Po-Chih Huang / @pochih


一、FCN是什么?

全卷积网络(Fully Convolutional Network, FCN),语义分割的奠基石,相关细节在工程代码注释中。

二、使用步骤

1.设置参数

代码如下(示例):

	// Device
	// Picks CUDA when torch reports it available, otherwise the CPU.
	// NOTE(review): the --cpu and --n_GPUs flags declared below are not consulted here in the visible code.
	auto cuda_available = torch::cuda::is_available();
	torch::Device device(cuda_available ? torch::kCUDA : torch::kCPU);
	std::cout << (cuda_available ? "CUDA available. Training on GPU." : "Training on CPU.") << '\n';

	// Seeds the C rand() generator; its consumers are not visible in this snippet.
	srand(time(NULL));
	// Command-line options (args library); the last argument of each ValueFlag is its default.
	// NOTE(review): no parser.ParseCLI(...) call appears between these declarations and the
	// args::get() calls below; unless parsing happens elsewhere, the defaults are presumably used — confirm.
	args::ArgumentParser parser("This is a Semantic segmentation using fcn.", "This goes after the options.");
	args::HelpFlag help(parser, "help", "Display this help menu", { 'h', "help" });
	//# Hardware specifications
	args::ValueFlag<int> n_threads(parser, "n_threads", "number of threads for data loading", { "n_threads" }, 1);
	args::ValueFlag<bool> cpu(parser, "cpu", "use cpu only", { "cpu" }, 0);
	args::ValueFlag<int> n_GPUs(parser, "n_GPUs", "number of GPUs", { "n_GPUs" }, 1);

	//# Data specifications
	// root_dir and model_dir are not read in this snippet.
	args::ValueFlag<std::string> root_dir(parser, "root_dir", "dataset directory", { "root_dir" }, "..\\CamVid");
	args::ValueFlag<std::string> model_dir(parser, "model_dir", "create dir for model", { "model_dir" }, "..\\save_models");
	args::ValueFlag<std::string> train_csv_dir(parser, "train_csv_dir", "dataset directory", { "train_csv_dir" }, "..\\CamVid\\train.csv");
	args::ValueFlag<std::string> val_csv_dir(parser, "val_csv_dir", "dataset directory", { "val_csv_dir" }, "..\\CamVid\\val.csv");

	//# Training specifications
	args::ValueFlag<int> n_class(parser, "n_class", "Pixel classification", { "n_class" }, 32);
	args::ValueFlag<int> batch_size(parser, "batch_size", "train patch size", { "batch_size" }, 1);
	args::ValueFlag<int> epochs(parser, "epochs", "number of epochs to train", { "epochs" }, 500);

	//# Optimization specifications
	// The help text of `momentum` mentions RMSprop; the optimizer itself is not shown in this snippet.
	// epochs, step_size and gamma are not read in this snippet (presumably used by the training loop — verify).
	args::ValueFlag<double> lr(parser, "lr", "learning rate", { "lr" }, 1e-4);
	args::ValueFlag<double> momentum(parser, "momentum", "RMSprop", { "momentum" }, 0);
	args::ValueFlag<double> weight_decay(parser, "weight_decay", "weight decay", { "weight_decay" }, 1e-5);
	args::ValueFlag<int> step_size(parser, "step_size", "lr scheduler step", { "step_size" }, 50);
	args::ValueFlag<double> gamma(parser, "gamma", "decay LR by a factor of gamma", { "gamma" }, 0.5);

	//get parameters
	// Copy the option values into plain locals: CSV paths, optimizer settings,
	// c = number of classes, batch = data-loader batch size.
	std::string train_csv_file = (std::string)args::get(train_csv_dir);
	std::string	val_csv_file = (std::string)args::get(val_csv_dir);
	double learning_rate = (double)args::get(lr);
	double momentum_ = (double)args::get(momentum);
	double weight_decay_ = (double)args::get(weight_decay);
	int c = (int)args::get(n_class);
	int batch = (int)args::get(batch_size);

2.读入数据

代码如下(示例):

	//process
	// ImageNet mean/std listed in B, G, R order (images presumably come from OpenCV, which is BGR — verify).
	std::vector<double> norm_mean = { 0.406, 0.456, 0.485 };
	std::vector<double> norm_std = { 0.225, 0.224, 0.229 };
	// 103.939/255 ~= 0.4076, 116.779/255 ~= 0.4580, 123.68/255 ~= 0.4850: close to (not identical with)
	// norm_mean above. Mind the BGR vs RGB channel order.
	std::tuple<float, float, float> mean = std::tuple<float, float, float>(103.939 / 255, 116.779 / 255, 123.68 / 255);
	// Pass the configured class count `c` (from --n_class). The val dataset below fills the same
	// argument position with `c`; the train set previously hard-coded 32, which diverged whenever
	// --n_class was changed.
	auto train_dataset = CamVidDataset(train_csv_file, "train", c, true, 0.5);
		/*.map(torch::data::transforms::Normalize<>(norm_mean, norm_std))*/
	
	// Train_Data loader
	auto train_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(
		std::move(train_dataset), batch);

	auto val_dataset = CamVidDataset(val_csv_file, "val", c, true, 0.);

	// Val_Data loader
	auto val_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(
		std::move(val_dataset), batch);

3.FCN网络

//因为暂时不会通过torch::nn::Sequential获取中间层输出
//所以目前使用最为原始的方式来处理
//同时注意这里skip连接使用的是maxpool2d特征,要和其他类似网络结构进行对比,例如DenseNet
//以下VGG结构体共有16个卷积层(2-2-4-4-4,即VGG19的配置);FCN8s中的骨干只使用其中13个卷积层(VGG16的配置)
struct VGG : public torch::nn::Module
{
	VGG(bool nBN) : in_channels(3)
	{
		conv2d_1 = register_module("conv2d_1", torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, 64, 3).padding(1)));
		relu_1 = register_module("relu_1", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_2 = register_module("conv2d_2", torch::nn::Conv2d(torch::nn::Conv2dOptions(64, 64, 3).padding(1)));
		relu_2 = register_module("relu_2", torch::nn::ReLU(torch::nn::ReLUOptions(true)));

		conv2d_3 = register_module("conv2d_3", torch::nn::Conv2d(torch::nn::Conv2dOptions(64, 128, 3).padding(1)));
		relu_3 = register_module("relu_3", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_4 = register_module("conv2d_4", torch::nn::Conv2d(torch::nn::Conv2dOptions(128, 128, 3).padding(1)));
		relu_4 = register_module("relu_4", torch::nn::ReLU(torch::nn::ReLUOptions(true)));

		conv2d_5 = register_module("conv2d_5", torch::nn::Conv2d(torch::nn::Conv2dOptions(128, 256, 3).padding(1)));
		relu_5 = register_module("relu_5", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_6 = register_module("conv2d_6", torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 256, 3).padding(1)));
		relu_6 = register_module("relu_6", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_7 = register_module("conv2d_7", torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 256, 3).padding(1)));
		relu_7 = register_module("relu_7", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_8 = register_module("conv2d_8", torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 256, 3).padding(1)));
		relu_8 = register_module("relu_8", torch::nn::ReLU(torch::nn::ReLUOptions(true)));

		conv2d_9 = register_module("conv2d_9", torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 512, 3).padding(1)));
		relu_9 = register_module("relu_9", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_10 = register_module("conv2d_10", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)));
		relu_10 = register_module("relu_10", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_11 = register_module("conv2d_11", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)));
		relu_11 = register_module("relu_11", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_12 = register_module("conv2d_12", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)));
		relu_12 = register_module("relu_12", torch::nn::ReLU(torch::nn::ReLUOptions(true)));

		conv2d_13 = register_module("conv2d_13", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)));
		relu_13 = register_module("relu_13", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_14 = register_module("conv2d_14", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)));
		relu_14 = register_module("relu_14", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_15 = register_module("conv2d_15", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)));
		relu_15 = register_module("relu_15", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_16 = register_module("conv2d_16", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)));
		relu_16 = register_module("relu_16", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
	}

	~VGG()
	{

	}

	torch::Tensor forward(torch::Tensor input)
	{
		xmaxpool_1 = torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2))(relu_2(conv2d_2(relu_1(conv2d_1(input)))));
		xmaxpool_2 = torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2))(relu_4(conv2d_4(relu_3(conv2d_3(xmaxpool_1)))));
		xmaxpool_3 = torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2))
			(relu_8((conv2d_8((relu_7((conv2d_7((relu_6(conv2d_6(relu_5(conv2d_5(xmaxpool_2)))))))))))));
		xmaxpool_4 = torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2))
			(relu_12((conv2d_12((relu_11((conv2d_11((relu_10(conv2d_10(relu_9(conv2d_9(xmaxpool_3)))))))))))));
		xmaxpool_5 = torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2))
			(relu_16((conv2d_16((relu_15((conv2d_15((relu_14(conv2d_14(relu_13(conv2d_13(xmaxpool_4)))))))))))));
		return xmaxpool_5;
	}
	torch::Tensor xmaxpool_1, xmaxpool_2, xmaxpool_3, xmaxpool_4, xmaxpool_5;

	torch::nn::Conv2d conv2d_1{ nullptr }, conv2d_2{ nullptr }, conv2d_3{ nullptr }, conv2d_4{ nullptr }, conv2d_5{ nullptr }, conv2d_6{ nullptr };
	torch::nn::Conv2d conv2d_7{ nullptr }, conv2d_8{ nullptr }, conv2d_9{ nullptr }, conv2d_10{ nullptr }, conv2d_11{ nullptr }, conv2d_12{ nullptr };
	torch::nn::Conv2d conv2d_13{ nullptr }, conv2d_14{ nullptr }, conv2d_15{ nullptr }, conv2d_16{ nullptr };
	torch::nn::ReLU relu_1{ nullptr }, relu_2{ nullptr }, relu_3{ nullptr }, relu_4{ nullptr }, relu_5{ nullptr }, relu_6{ nullptr };
	torch::nn::ReLU relu_7{ nullptr }, relu_8{ nullptr }, relu_9{ nullptr }, relu_10{ nullptr }, relu_11{ nullptr }, relu_12{ nullptr };
	torch::nn::ReLU relu_13{ nullptr }, relu_14{ nullptr }, relu_15{ nullptr }, relu_16{ nullptr };

private:
	int  in_channels;
};

struct FCN8s : public torch::nn::Module
{
	FCN8s(int c) :n_class(c), in_channels(3)
	{
		//vgg
		conv2d_1 = register_module("conv2d_1", torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, 64, 3).padding(1)));
		relu_1 = register_module("relu_1", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_2 = register_module("conv2d_2", torch::nn::Conv2d(torch::nn::Conv2dOptions(64, 64, 3).padding(1)));
		relu_2 = register_module("relu_2", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		maxpool_1 = register_module("maxpool_1", torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2)));

		conv2d_3 = register_module("conv2d_3", torch::nn::Conv2d(torch::nn::Conv2dOptions(64, 128, 3).padding(1)));
		relu_3 = register_module("relu_3", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_4 = register_module("conv2d_4", torch::nn::Conv2d(torch::nn::Conv2dOptions(128, 128, 3).padding(1)));
		relu_4 = register_module("relu_4", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		maxpool_2 = register_module("maxpool_2", torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2)));

		conv2d_5 = register_module("conv2d_5", torch::nn::Conv2d(torch::nn::Conv2dOptions(128, 256, 3).padding(1)));
		relu_5 = register_module("relu_5", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_6 = register_module("conv2d_6", torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 256, 3).padding(1)));
		relu_6 = register_module("relu_6", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_7 = register_module("conv2d_7", torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 256, 3).padding(1)));
		relu_7 = register_module("relu_7", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		maxpool_3 = register_module("maxpool_3", torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2)));

		conv2d_9 = register_module("conv2d_9", torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 512, 3).padding(1)));
		relu_9 = register_module("relu_9", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_10 = register_module("conv2d_10", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)));
		relu_10 = register_module("relu_10", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_11 = register_module("conv2d_11", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)));
		relu_11 = register_module("relu_11", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		maxpool_4 = register_module("maxpool_4", torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2)));

		conv2d_13 = register_module("conv2d_13", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)));
		relu_13 = register_module("relu_13", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_14 = register_module("conv2d_14", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)));
		relu_14 = register_module("relu_14", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_15 = register_module("conv2d_15", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)));
		relu_15 = register_module("relu_15", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		maxpool_5 = register_module("maxpool_5", torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2)));
		//fcns
		
		convTranspose2d_1 = register_module("convTranspose2d_1", torch::nn::ConvTranspose2d(torch::nn::ConvTranspose2dOptions(512, 512, 3).stride(2).padding(1).dilation(1).output_padding(1)));
		relu_fcn_1 = register_module("relu_fcn_1", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		batchNorm2d_1 = register_module("batchNorm2d_1", torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(512)));
		convTranspose2d_2 = register_module("convTranspose2d_2", torch::nn::ConvTranspose2d(torch::nn::ConvTranspose2dOptions(512, 256, 3).stride(2).padding(1).dilation(1).output_padding(1)));
		relu_fcn_2 = register_module("relu_fcn_2", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		batchNorm2d_2 = register_module("batchNorm2d_2", torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(256)));
		convTranspose2d_3 = register_module("convTranspose2d_3", torch::nn::ConvTranspose2d(torch::nn::ConvTranspose2dOptions(256, 128, 3).stride(2).padding(1).dilation(1).output_padding(1)));
		relu_fcn_3 = register_module("relu_fcn_3", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		batchNorm2d_3 = register_module("batchNorm2d_3", torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(128)));
		convTranspose2d_4 = register_module("convTranspose2d_4", torch::nn::ConvTranspose2d(torch::nn::ConvTranspose2dOptions(128, 64, 3).stride(2).padding(1).dilation(1).output_padding(1)));
		relu_fcn_4 = register_module("relu_fcn_4", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		batchNorm2d_4 = register_module("batchNorm2d_4", torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(64)));
		convTranspose2d_5 = register_module("convTranspose2d_5", torch::nn::ConvTranspose2d(torch::nn::ConvTranspose2dOptions(64, 32, 3).stride(2).padding(1).dilation(1).output_padding(1)));
		relu_fcn_5 = register_module("relu_fcn_5", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		batchNorm2d_5 = register_module("batchNorm2d_5", torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(32)));
		classifier = register_module("conv2d", torch::nn::Conv2d(torch::nn::Conv2dOptions(32, n_class, 1)));
	}
	~FCN8s()
	{

	}

	torch::Tensor forward(torch::Tensor input)
	{
		//vgg inference
		xmaxpool_1 = maxpool_1((relu_2(conv2d_2(relu_1(conv2d_1(input))))));
		xmaxpool_2 = maxpool_2((relu_4(conv2d_4(relu_3(conv2d_3(xmaxpool_1))))));
		xmaxpool_3 = maxpool_3(((relu_7((conv2d_7((relu_6(conv2d_6(relu_5(conv2d_5(xmaxpool_2)))))))))));
		xmaxpool_4 = maxpool_4((relu_11((conv2d_11((relu_10(conv2d_10(relu_9(conv2d_9(xmaxpool_3))))))))));
		xmaxpool_5 = maxpool_5((relu_15((conv2d_15((relu_14(conv2d_14(relu_13(conv2d_13(xmaxpool_4))))))))));
		//fcns inference
		score = relu_fcn_1(convTranspose2d_1(xmaxpool_5));
		score = score.add(xmaxpool_4);
		score = batchNorm2d_1(score);
		score = relu_fcn_2(convTranspose2d_2(score));
		score = score.add(xmaxpool_3);
		score = batchNorm2d_2(score);
		score = batchNorm2d_3(relu_fcn_3(convTranspose2d_3(score)));
		score = batchNorm2d_4(relu_fcn_4(convTranspose2d_4(score)));
		score = batchNorm2d_5(relu_fcn_5(convTranspose2d_5(score)));
		score = classifier(score);
		return score;
	}

	//fcns
	torch::Tensor score;
	int n_class;
	torch::nn::ReLU relu_fcn_1{ nullptr }, relu_fcn_2{ nullptr }, relu_fcn_3{ nullptr }, relu_fcn_4{ nullptr }, relu_fcn_5{ nullptr };
	torch::nn::ConvTranspose2d convTranspose2d_1{ nullptr }, convTranspose2d_2{ nullptr }, convTranspose2d_3{ nullptr }, convTranspose2d_4{ nullptr }, convTranspose2d_5{ nullptr };
	torch::nn::Conv2d classifier{ nullptr };
	torch::nn::BatchNorm2d batchNorm2d_1{ nullptr }, batchNorm2d_2{ nullptr }, batchNorm2d_3{ nullptr }, batchNorm2d_4{ nullptr }, batchNorm2d_5{ nullptr };

	//vgg
	int  in_channels;
	torch::Tensor xmaxpool_1, xmaxpool_2, xmaxpool_3, xmaxpool_4, xmaxpool_5;

	torch::nn::Conv2d conv2d_1{ nullptr }, conv2d_2{ nullptr }, conv2d_3{ nullptr }, conv2d_4{ nullptr }, conv2d_5{ nullptr }, conv2d_6{ nullptr };
	torch::nn::Conv2d conv2d_7{ nullptr }, conv2d_8{ nullptr }, conv2d_9{ nullptr }, conv2d_10{ nullptr }, conv2d_11{ nullptr }, conv2d_12{ nullptr };
	torch::nn::Conv2d conv2d_13{ nullptr }, conv2d_14{ nullptr }, conv2d_15{ nullptr }, conv2d_16{ nullptr };
	torch::nn::ReLU relu_1{ nullptr }, relu_2{ nullptr }, relu_3{ nullptr }, relu_4{ nullptr }, relu_5{ nullptr }, relu_6{ nullptr };
	torch::nn::ReLU relu_7{ nullptr }, relu_8{ nullptr }, relu_9{ nullptr }, relu_10{ nullptr }, relu_11{ nullptr }, relu_12{ nullptr };
	torch::nn::ReLU relu_13{ nullptr }, relu_14{ nullptr }, relu_15{ nullptr }, relu_16{ nullptr };
	torch::nn::MaxPool2d maxpool_1{ nullptr }, maxpool_2{ nullptr }, maxpool_3{ nullptr }, maxpool_4{ nullptr }, maxpool_5{ nullptr };
};

4.IOU

// Per-class intersection-over-union for a single image.
//   pred   : predicted class index per pixel (any integer dtype, any device); a leading
//            dimension of size 1 (batch) is squeezed away.
//   target : ground-truth class index per pixel (plain labels, not one-hot); must contain
//            the same number of pixels as pred.
// Returns one entry per class in [0, num_class): intersection / union for that class, or NaN
// when the class occurs in neither pred nor target. Pixel values outside [0, num_class) are
// not counted for any class. Uses the file-level `num_class` defined elsewhere.
// Throws (TORCH_CHECK) if pred and target hold a different number of pixels.
std::vector<float> iou(torch::Tensor pred, torch::Tensor target)
{
	// Bring both tensors to CPU as contiguous int32 so they can be scanned via a raw pointer.
	// contiguous() matters: toType() is a no-op when the dtype already matches, so a strided
	// view would otherwise be walked in the wrong order.
	pred = pred.to(torch::kCPU).toType(torch::kInt32).squeeze(0).contiguous();
	target = target.to(torch::kCPU).toType(torch::kInt32).contiguous();

	const int64_t pixels = target.numel();
	// Guard against reading past the end of pred when the shapes disagree.
	TORCH_CHECK(pred.numel() == pixels,
		"iou: pred has ", pred.numel(), " elements but target has ", pixels);

	const int classes = num_class;
	const std::size_t n_bins = static_cast<std::size_t>(classes > 0 ? classes : 0);
	// Per-class pixel counts: predicted as cls, labelled as cls, and both (intersection).
	std::vector<int64_t> pred_hits(n_bins, 0);
	std::vector<int64_t> target_hits(n_bins, 0);
	std::vector<int64_t> overlap(n_bins, 0);

	// A single pass over the pixels replaces one full scan (plus mask reset) per class.
	const int* pred_array = pred.data_ptr<int>();
	const int* target_array = target.data_ptr<int>();
	for (int64_t i = 0; i < pixels; ++i)
	{
		const int p = pred_array[i];
		const int t = target_array[i];
		const bool p_valid = p >= 0 && p < classes;
		const bool t_valid = t >= 0 && t < classes;
		if (p_valid)
			++pred_hits[p];
		if (t_valid)
			++target_hits[t];
		if (p_valid && p == t)
			++overlap[p];	// intersection: pixel is cls in both pred and target
	}

	std::vector<float> ious;	// IoU of each of the n_class classes for this image
	ious.reserve(n_bins);
	for (std::size_t cls = 0; cls < n_bins; ++cls)
	{
		// union = |pred == cls| + |target == cls| - |both|
		const int64_t union_ = pred_hits[cls] + target_hits[cls] - overlap[cls];
		if (union_ == 0)
			ious.push_back(std::nanf("nan"));	// class absent from both pred and ground truth
		else
			ious.push_back(static_cast<float>(overlap[cls]) / static_cast<float>(union_));
	}
	return ious;
}

5.验证

// Validation for semantic segmentation. Runs the model over every validation example and prints
// the average pixel accuracy and the average per-image mean IoU for epoch nEpoch.
//   fcn8s      : model; moved to `device` and put in eval mode (not switched back to train mode here).
//   val_loader : pointer-like handle to a data loader whose batches are sequences of Examples
//                with `data` (image) and `target` (label map).
// Pixel accuracy and IoU measure different things: accuracy counts matching pixels, IoU measures
// overlap per class. If car pixels are right but road pixels are wrong, accuracy can be high
// while IoU stays low.
template <typename DataLoader>
void Train::val(int nEpoch, FCN8s& fcn8s, torch::Device device, DataLoader& val_loader)
{
	fcn8s.to(device);
	fcn8s.eval();
	// Inference only: don't build an autograd graph (saves memory and time).
	torch::NoGradGuard no_grad;

	long long accumulationCost = 0;	// total forward time in ms (only used by the commented-out report below)
	float totalMeanIoU = .0;
	float totalPixel_accs = .0;
	int N = 0;	// number of examples actually evaluated

	for (auto& batch : *val_loader)
	{
		// Evaluate every example in the batch, not only the first, so batch_size > 1 still covers the whole set.
		for (auto& example : batch)
		{
			if (!example.data.numel())
			{
				std::cout << "tensor is empty!" << std::endl;
				continue;	// skipped examples are not counted in N
			}
			torch::Tensor input = example.data.unsqueeze(0).to(device);	// add batch dimension
			torch::Tensor target = example.target.to(device);

			auto tm_start = std::chrono::system_clock::now();
			auto fcns_output = fcn8s.forward(input);
			auto tm_end = std::chrono::system_clock::now();
			accumulationCost += std::chrono::duration_cast<std::chrono::milliseconds>(tm_end - tm_start).count();

			// Output is {1, n_class, h, w}; argmax over the class dimension gives each pixel's class index, {1, h, w}.
			torch::Tensor pred = fcns_output.argmax(1);

			// Per-class IoU for this image; NaN marks classes that occur in neither pred nor target.
			std::vector<float> ious = iou(pred, target);
			float iouSum = .0;
			int validClasses = 0;
			for (float v : ious)
			{
				if (std::isnan(v))
					continue;
				iouSum += v;
				++validClasses;
			}
			// NaN-aware mean: divide by the classes that contributed, not by num_class.
			float meanIoU = validClasses > 0 ? iouSum / validClasses : .0f;
			totalMeanIoU += meanIoU;
			float pixel_accs = pixel_acc(pred, target);
			totalPixel_accs += pixel_accs;
			//cout << "meanIoU: " << meanIoU << " pixel_accs: " << pixel_accs << endl;
			++N;
		}
	}
	if (N == 0)
	{
		// Avoid dividing by zero when nothing could be evaluated.
		printf("epoch{%d}, no validation samples\n", nEpoch);
		return;
	}
	totalMeanIoU /= N;
	totalPixel_accs /= N;
	printf("epoch{%d}, pix_acc: {%0.6f}, meanIoU: {%0.6f}\n", nEpoch, totalPixel_accs, totalMeanIoU);
	//printf("epoch {%d}, meanIoU:{%0.5f} cost:{%lld msec}\n", nEpoch, meanIoU, accumulationCost);
}

总结

本系列博客均非纯理论或者初级入门类文章,具有较强工程化内容,后续将介绍音视频GB28181系统平台,最后结合音视频平台完成示范性AI落地运用。

  • 4
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值