libtorch-VGG cifar10分类

libtorch-VGG cifar10分类

libtorch-UNET Carvana车辆轮廓分割
libtorch-VGG cifar10分类
libtorch-FCN CamVid语义分割
libtorch-RDN DIV2K超分重建
libtorch-char-rnn-classification libtorch官网系列教程
libtorch-char-rnn-generation libtorch官网系列教程
libtorch-char-rnn-shakespeare libtorch官网系列教程
libtorch-minst libtorch官网系列教程
————————————————

                        版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。

原文链接:https://blog.csdn.net/barcaFC/article/details/140002947



前言

libtorch C++ pytorch vgg
从数据集准备、训练、推理 全部由libtorch c++完成
运行环境:
操作系统:windows 64位
opencv3.4.1
libtorch 任意版本 64位
visual studio 2017 64位编译
数据集:cifar10
参考论文:https://paperswithcode.com/method/vgg

一、VGG是什么?

是早期深度学习网络的一次重大改进,将特征提取的卷积缩小(3*3),深度加深,最后使用全链接,使得整个网络结构非常简洁,在早期深度学习图片分类效果中取得非常优秀的成绩,至今仍旧适用于大量场景。

为什么要Normalize?
在这里有两点:第一,VGG网络权重的初始化pytorch和libtorch是一样的,默认使用何凯明初始化,而何凯明初始化很特别的针对具有ReLU激活的网络2/sqrt(n)
第二,对于数据集的初始化,Normalize((img - mean) / div(std)),如果不进行Normalize则img的值是相当大,在进行反向梯度的时候multiplied by a learning rate
将会导致W很难收敛,甚至于严重摇晃导致失败
The goal of applying Feature Scaling is to make sure features are on almost the same scale so that each feature is equally important and make it easier to process by most ML algorithms.
本例子可以把Normalize去掉就可以发生明显的收敛速度对比
https://stats.stackexchange.com/questions/185853/why-do-we-need-to-normalize-the-images-before-we-put-them-into-cnn
链接的文章很好的讲述了为什么要进行Normalize

二、使用步骤

1.设置参数

代码如下(示例):

	// Device
	auto cuda_available = torch::cuda::is_available();
	torch::Device device(cuda_available ? torch::kCUDA : torch::kCPU);
	std::cout << (cuda_available ? "CUDA available. Training on GPU." : "Training on CPU.") << '\n';

	// Hyper parameters
	const int64_t num_classes = 10;
	const int64_t batch_size = 100;
	const size_t num_epochs = 900;
	const double learning_rate = 0.00001;
	const size_t learning_rate_decay_frequency = 300;  // number of epochs after which to decay the learning rate
	const double learning_rate_decay_factor = 1.0 / 2.0;//1.0 / 3.0;

2.读入数据

代码如下(示例):

	const std::string CIFAR_data_path = "../dataset/cifar10/";
	std::vector<double> norm_mean = { 0.485, 0.456, 0.406 };
	std::vector<double> norm_std = { 0.229, 0.224, 0.225 };
	// CIFAR10 custom dataset
	auto train_dataset = CIFAR10(CIFAR_data_path)
		.map(ConstantPad(4))
		.map(RandomHorizontalFlip())
		.map(RandomCrop({ 32, 32 }))
		.map(torch::data::transforms::Normalize<>(norm_mean, norm_std))
		.map(torch::data::transforms::Stack<>());

	
	// Number of samples in the training set
	auto num_train_samples = train_dataset.size().value();

	auto test_dataset = CIFAR10(CIFAR_data_path, CIFAR10::Mode::kTest)
		.map(torch::data::transforms::Stack<>());

	// Number of samples in the testset
	auto num_test_samples = test_dataset.size().value();

	// Data loader
	auto train_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(
		std::move(train_dataset), batch_size);

	auto test_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
		std::move(test_dataset), batch_size);

3.VGG网络

代码如下(示例):
	struct VGG : public torch::nn::Module
{
	VGG():in_channels(3)
	{
		conv2d_1 = register_module("conv2d_1", torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, 64, 3).padding(1)));
		relu_1 = register_module("relu_1", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_2 = register_module("conv2d_2", torch::nn::Conv2d(torch::nn::Conv2dOptions(64, 64, 3).padding(1)));
		relu_2 = register_module("relu_2", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		maxpool_1 = register_module("maxpool_1", torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2)));

		conv2d_3 = register_module("conv2d_3", torch::nn::Conv2d(torch::nn::Conv2dOptions(64, 128, 3).padding(1)));
		relu_3 = register_module("relu_3", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_4 = register_module("conv2d_4", torch::nn::Conv2d(torch::nn::Conv2dOptions(128, 128, 3).padding(1)));
		relu_4 = register_module("relu_4", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		maxpool_2 = register_module("maxpool_2", torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2)));

		conv2d_5 = register_module("conv2d_5", torch::nn::Conv2d(torch::nn::Conv2dOptions(128, 256, 3).padding(1)));
		relu_5 = register_module("relu_5", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_6 = register_module("conv2d_6", torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 256, 3).padding(1)));
		relu_6 = register_module("relu_6", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_7 = register_module("conv2d_7", torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 256, 3).padding(1)));
		relu_7 = register_module("relu_7", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		maxpool_3 = register_module("maxpool_3", torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2)));

		conv2d_9 = register_module("conv2d_9", torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 512, 3).padding(1)));
		relu_9 = register_module("relu_9", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_10 = register_module("conv2d_10", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)));
		relu_10 = register_module("relu_10", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_11 = register_module("conv2d_11", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)));
		relu_11 = register_module("relu_11", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		maxpool_4 = register_module("maxpool_4", torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2)));

		conv2d_13 = register_module("conv2d_13", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)));
		relu_13 = register_module("relu_13", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_14 = register_module("conv2d_14", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)));
		relu_14 = register_module("relu_14", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		conv2d_15 = register_module("conv2d_15", torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1)));
		relu_15 = register_module("relu_15", torch::nn::ReLU(torch::nn::ReLUOptions(true)));
		maxpool_5 = register_module("maxpool_5", torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2)));

		//VGG 分类器,对于CIFAR10,最后10个分类
		torch::nn::Sequential classifier(
			torch::nn::Dropout(),
			torch::nn::Linear(512, 512),
			torch::nn::ReLU(torch::nn::ReLUOptions(true)),
			torch::nn::Dropout(),
			torch::nn::Linear(512, 512),
			torch::nn::ReLU(torch::nn::ReLUOptions(true)),
			torch::nn::Linear(512, 10));
		classifierModel = register_module("classifier", classifier.ptr());
	}

	torch::Tensor forward(torch::Tensor input)
	{
		//vgg inference
		xmaxpool_1 = maxpool_1((relu_2(conv2d_2(relu_1(conv2d_1(input))))));
		xmaxpool_2 = maxpool_2((relu_4(conv2d_4(relu_3(conv2d_3(xmaxpool_1))))));
		xmaxpool_3 = maxpool_3(((relu_7((conv2d_7((relu_6(conv2d_6(relu_5(conv2d_5(xmaxpool_2)))))))))));
		xmaxpool_4 = maxpool_4((relu_11((conv2d_11((relu_10(conv2d_10(relu_9(conv2d_9(xmaxpool_3))))))))));
		xmaxpool_5 = maxpool_5((relu_15((conv2d_15((relu_14(conv2d_14(relu_13(conv2d_13(xmaxpool_4))))))))));

		xmaxpool_5 = xmaxpool_5.view({ xmaxpool_5.size(0), -1 });
		x = classifierModel->forward(xmaxpool_5);

		return x;
	}

	//vgg
	int  in_channels;
	torch::Tensor x;
	torch::Tensor xmaxpool_1, xmaxpool_2, xmaxpool_3, xmaxpool_4, xmaxpool_5;

	torch::nn::Conv2d conv2d_1{ nullptr }, conv2d_2{ nullptr }, conv2d_3{ nullptr }, conv2d_4{ nullptr }, conv2d_5{ nullptr }, conv2d_6{ nullptr };
	torch::nn::Conv2d conv2d_7{ nullptr }, conv2d_8{ nullptr }, conv2d_9{ nullptr }, conv2d_10{ nullptr }, conv2d_11{ nullptr }, conv2d_12{ nullptr };
	torch::nn::Conv2d conv2d_13{ nullptr }, conv2d_14{ nullptr }, conv2d_15{ nullptr }, conv2d_16{ nullptr };
	torch::nn::ReLU relu_1{ nullptr }, relu_2{ nullptr }, relu_3{ nullptr }, relu_4{ nullptr }, relu_5{ nullptr }, relu_6{ nullptr };
	torch::nn::ReLU relu_7{ nullptr }, relu_8{ nullptr }, relu_9{ nullptr }, relu_10{ nullptr }, relu_11{ nullptr }, relu_12{ nullptr };
	torch::nn::ReLU relu_13{ nullptr }, relu_14{ nullptr }, relu_15{ nullptr }, relu_16{ nullptr };
	torch::nn::MaxPool2d maxpool_1{ nullptr }, maxpool_2{ nullptr }, maxpool_3{ nullptr }, maxpool_4{ nullptr }, maxpool_5{ nullptr };

	torch::nn::Sequential classifierModel;
};

三、总结

针对C++出身的开发人员,以及需要工程化落地开发人员,libtorch的C++版本提供了一个很好的学习平台,通过翻译各种CVPR论中的python工程能够很清晰了解和掌握深度学习各种框架中的深度问题,例如损失值摇晃无法下降,梯度消失,均值化的必要性等等,例如本文,如果没均值化框架参数将很难收敛就算VGG网络结构简单清晰。

  • 34
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值