libtorch study notes (5): MNIST in practice


A hands-on MNIST exercise with libtorch

Preparation

First, download the MNIST database from http://yann.lecun.com/exdb/mnist/
After downloading, do not unpack the archives with WinRAR or similar Windows tools: they rename files such as t10k-images-idx3-ubyte to t10k-images.idx3-ubyte, which the loader will not find. It is safest to decompress the .gz files with gunzip in a Linux environment.
Assume the files are extracted to I:\MNIST.

Training and saving the result

Defining the network

This is much the same as before, but padding is now needed: the MNIST training and test images are 28x28, while LeNet-5 expects 32x32 input, so the first convolution layer gets a padding of (32 - 28)/2 = 2. Compared with the earlier code, only a small change is required:

struct LeNet5 : torch::nn::Module
{
	// padding can be passed into convolution layer C1 to pad the input image up to 32x32
	LeNet5(int arg_padding=0)
		: C1(register_module("C1", torch::nn::Conv2d(torch::nn::Conv2dOptions(1, 6, 5).padding(arg_padding))))
		, C3(register_module("C3", torch::nn::Conv2d(6, 16, 5)))
		, F5(register_module("F5", torch::nn::Linear(16 * 5 * 5, 120)))
		, F6(register_module("F6", torch::nn::Linear(120, 84)))
		, OUTPUT(register_module("OUTPUT", torch::nn::Linear(84, 10)))
	{
	}

	~LeNet5()
	{
	}

	int64_t num_flat_features(torch::Tensor input)
	{
		int64_t num_features = 1;
		auto sizes = input.sizes();
		// skip dim 0 (the batch dimension); only flatten C x H x W
		for (size_t i = 1; i < sizes.size(); i++) {
			num_features *= sizes[i];
		}
		return num_features;
	}

	torch::Tensor forward(torch::Tensor input)
	{
		namespace F = torch::nn::functional;
		// 2x2 Max pooling
		auto x = F::max_pool2d(F::relu(C1(input)), F::MaxPool2dFuncOptions({ 2,2 }));
		// a single number is enough when the pooling window is square
		x = F::max_pool2d(F::relu(C3(x)), F::MaxPool2dFuncOptions(2));
		x = x.view({ -1, num_flat_features(x) });
		x = F::relu(F5(x));
		x = F::relu(F6(x));
		x = OUTPUT(x);
		return x;
	}

	torch::nn::Conv2d	C1;
	torch::nn::Conv2d	C3;
	torch::nn::Linear	F5;
	torch::nn::Linear	F6;
	torch::nn::Linear	OUTPUT;
};

Start training

Consider the following code:

	{
		tm_start = std::chrono::system_clock::now();
		auto dataset = torch::data::datasets::MNIST("I:\\MNIST\\")
			.map(torch::data::transforms::Normalize<>(0.5, 0.5))
			.map(torch::data::transforms::Stack<>());
		auto data_loader = torch::data::make_data_loader(std::move(dataset));

		tm_end = std::chrono::system_clock::now();

		printf("It takes %lld msec to load MNIST handwriting database.\n", 
			std::chrono::duration_cast<std::chrono::milliseconds>(tm_end - tm_start).count());

		tm_start = std::chrono::system_clock::now();
		// the input images are 28x28, so set padding to 2 to make them 32x32
		LeNet5 net1(2);

		auto criterion = torch::nn::CrossEntropyLoss();
		auto optimizer = torch::optim::SGD(net1.parameters(), torch::optim::SGDOptions(0.001).momentum(0.9));
		tm_end = std::chrono::system_clock::now();
		printf("It takes %lld msec to prepare training handwriting.\n",
			std::chrono::duration_cast<std::chrono::milliseconds>(tm_end - tm_start).count());

		tm_start = std::chrono::system_clock::now();
		int64_t kNumberOfEpochs = 2;
		for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {

			int i = 0;
			auto running_loss = 0.;
			for (torch::data::Example<>& batch : *data_loader) {

				auto inputs = batch.data;
				auto labels = batch.target;

				optimizer.zero_grad();
				// feed the data to the network
				auto outputs = net1.forward(inputs);
				// compute the loss with cross-entropy
				auto loss = criterion(outputs, labels);
				// back-propagate, so the optimizer knows how to adjust the weights
				loss.backward();
				optimizer.step();

				running_loss += loss.item().toFloat();
				if ((i + 1) % 3000 == 0)
				{
					printf("[%lld, %5d] loss: %.3f\n", epoch, i + 1, running_loss / 3000);
					running_loss = 0.;
				}

				i++;
			}
		}

		printf("Finish training!\n");
		torch::serialize::OutputArchive archive;
		net1.save(archive);
		archive.save_to("I:\\mnist.pt");
		printf("Save the training result to I:\\mnist.pt.\n");

		tm_end = std::chrono::system_clock::now();
		printf("It takes %lld msec to finish training handwriting!\n", 
			std::chrono::duration_cast<std::chrono::milliseconds>(tm_end - tm_start).count());
	}

Output

Under the Debug configuration this is painfully slow; switch to Release so optimizations are enabled. Training still takes a while, since there are 60,000 images to go through; on my machine it took a few minutes:
(screenshot of the training output)
The result looks good: after two epochs the loss has become fairly small.

Code walkthrough

The MNIST database format is described at http://yann.lecun.com/exdb/mnist/
train-images-idx3-ubyte: training set images
train-labels-idx1-ubyte: training set labels
t10k-images-idx3-ubyte: test set images
t10k-labels-idx1-ubyte: test set labels

The call

torch::data::datasets::MNIST("I:\\MNIST\\")

loads train-images-idx3-ubyte and train-labels-idx1-ubyte. The image file has the following layout:

TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000803(2051) magic number
0004     32 bit integer  60000            number of images
0008     32 bit integer  28               number of rows
0012     32 bit integer  28               number of columns
0016     unsigned byte   ??               pixel
0017     unsigned byte   ??               pixel
........
xxxx     unsigned byte   ??               pixel
Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).

The loader also normalizes each pixel to [0.0, 1.0]; then the statement

	.map(torch::data::transforms::Normalize<>(0.5, 0.5))

maps each pixel to [-1.0, 1.0], which is more convenient to work with. Written as formulas:

$\bar I = \overline{image} / 255.0$
$\bar D = (\bar I - 0.5) / 0.5$
The following statement then stacks the examples of each batch into a single tensor, turning a batch of N 3rd-order tensors (1x28x28) into one 4th-order tensor (Nx1x28x28); here the data loader uses the default batch size of 1:

.map(torch::data::transforms::Stack<>())

For a multi-class classification network like this, the cross-entropy loss is the usual choice:

auto criterion = torch::nn::CrossEntropyLoss();

The following code performs one training step: the network output and the ground-truth labels are fed into the cross-entropy loss, the loss is back-propagated through the network, and the Stochastic Gradient Descent (SGD) optimizer adjusts the parameters. This is how the network learns:

	// zero the gradients held by the optimizer
	optimizer.zero_grad();
	// feed the data to the network
	auto outputs = net1.forward(inputs);
	// compute the loss with cross-entropy
	auto loss = criterion(outputs, labels);
	// back-propagate, so the optimizer knows how to adjust the weights
	loss.backward();
	// let the optimizer update the network parameters
	optimizer.step();

Finally, when training is done, save the result so it can be loaded later:

	torch::serialize::OutputArchive archive;
	net1.save(archive);
	archive.save_to("I:\\mnist.pt");

Loading the training result and testing

The training result obtained above can be loaded and evaluated with the following code:

	{
		tm_start = std::chrono::system_clock::now();
		LeNet5 net1(2);
		torch::serialize::InputArchive archive;
		archive.load_from("I:\\mnist.pt");

		net1.load(archive);

		auto dataset = torch::data::datasets::MNIST("I:\\MNIST\\", torch::data::datasets::MNIST::Mode::kTest)
			.map(torch::data::transforms::Normalize<>(0.5, 0.5))
			.map(torch::data::transforms::Stack<>());
		auto data_loader = torch::data::make_data_loader(std::move(dataset));

		int total_test_items = 0, passed_test_items = 0;
		for (torch::data::Example<>& batch : *data_loader)
		{
			// run the test data through the trained network
			auto outputs = net1.forward(batch.data);
			// get the predicted digit, 0 - 9
			auto predicted = torch::max(outputs, 1);
			// get the ground-truth label, 0 - 9
			auto labels = batch.target;
			// compare the prediction with the ground truth and update the statistics
			if (labels[0].item<int>() == std::get<1>(predicted).item<int>())
				passed_test_items++;

			total_test_items++;

			//printf("label: %d.\n", labels[0].item<int>());
			//printf("predicted label: %d.\n", std::get<1>(predicted).item<int>());
			//std::cout << std::get<1>(predicted) << '\n';

			//break;
		}
		tm_end = std::chrono::system_clock::now();
		
		printf("Total test items: %d, passed test items: %d, pass rate: %.3f%%, cost %lld msec.\n", 
			total_test_items, passed_test_items, passed_test_items*100.f/total_test_items,
			std::chrono::duration_cast<std::chrono::milliseconds>(tm_end - tm_start).count());
	}

Output

The 10,000 test images took about 8 seconds, i.e. roughly 0.8 ms per image. Quite fast!
(screenshot of the test output)
