A Hands-On MNIST Exercise with libtorch
Preparation
First, download the MNIST database from http://yann.lecun.com/exdb/mnist/
Do not decompress the downloads with WinRAR or similar tools: they rewrite the file names, e.g. t10k-images-idx3-ubyte becomes t10k-images.idx3-ubyte. It is safest to decompress them in a Linux environment with gunzip (the files are plain gzip archives).
Assume they are extracted to I:\MNIST.
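A minimal sketch of the gunzip round trip, demonstrated on a dummy file so it is self-contained (the real archives come from the URL above):

```shell
# Demonstrate that gzip/gunzip preserve the -idx3- part of the name,
# unlike archivers that rewrite it to .idx3. Uses a dummy file.
printf 'demo' > train-images-idx3-ubyte
gzip train-images-idx3-ubyte        # produces train-images-idx3-ubyte.gz
gunzip train-images-idx3-ubyte.gz   # restores the original name unchanged
ls train-images-idx3-ubyte
```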
Training and Saving the Result
Defining the Network
Much the same as before, but this time we need padding: the MNIST training and test images are 28x28, while LeNet-5 expects a 32x32 input, so the first convolution layer gets a padding of (32 - 28) / 2 = 2. Compared with the earlier code, only a small change is needed:
struct LeNet5 : torch::nn::Module
{
	// Padding can be passed to the convolution layer C1 so that the
	// 28x28 input is treated as a 32x32 image.
	LeNet5(int arg_padding = 0)
		: C1(register_module("C1", torch::nn::Conv2d(torch::nn::Conv2dOptions(1, 6, 5).padding(arg_padding))))
		, C3(register_module("C3", torch::nn::Conv2d(6, 16, 5)))
		, F5(register_module("F5", torch::nn::Linear(16 * 5 * 5, 120)))
		, F6(register_module("F6", torch::nn::Linear(120, 84)))
		, OUTPUT(register_module("OUTPUT", torch::nn::Linear(84, 10)))
	{
	}
	~LeNet5()
	{
	}
	int64_t num_flat_features(torch::Tensor input)
	{
		// Multiply all dimensions except the batch dimension
		int64_t num_features = 1;
		auto sizes = input.sizes();
		for (size_t i = 1; i < sizes.size(); i++) {
			num_features *= sizes[i];
		}
		return num_features;
	}
	torch::Tensor forward(torch::Tensor input)
	{
		namespace F = torch::nn::functional;
		// 2x2 max pooling
		auto x = F::max_pool2d(F::relu(C1(input)), F::MaxPool2dFuncOptions({ 2, 2 }));
		// For a square window, a single number is enough
		x = F::max_pool2d(F::relu(C3(x)), F::MaxPool2dFuncOptions(2));
		x = x.view({ -1, num_flat_features(x) });
		x = F::relu(F5(x));
		x = F::relu(F6(x));
		x = OUTPUT(x);
		return x;
	}
	torch::nn::Conv2d C1;
	torch::nn::Conv2d C3;
	torch::nn::Linear F5;
	torch::nn::Linear F6;
	torch::nn::Linear OUTPUT;
};
Starting the Training
Consider the following code:
{
	tm_start = std::chrono::system_clock::now();
	auto dataset = torch::data::datasets::MNIST("I:\\MNIST\\")
		.map(torch::data::transforms::Normalize<>(0.5, 0.5))
		.map(torch::data::transforms::Stack<>());
	auto data_loader = torch::data::make_data_loader(std::move(dataset));
	tm_end = std::chrono::system_clock::now();
	printf("It takes %lld msec to load MNIST handwriting database.\n",
		std::chrono::duration_cast<std::chrono::milliseconds>(tm_end - tm_start).count());
	tm_start = std::chrono::system_clock::now();
	// The input images are 28x28, so set padding to 2 to make them 32x32
	LeNet5 net1(2);
	auto criterion = torch::nn::CrossEntropyLoss();
	auto optimizer = torch::optim::SGD(net1.parameters(), torch::optim::SGDOptions(0.001).momentum(0.9));
	tm_end = std::chrono::system_clock::now();
	printf("It takes %lld msec to prepare training handwriting.\n",
		std::chrono::duration_cast<std::chrono::milliseconds>(tm_end - tm_start).count());
	tm_start = std::chrono::system_clock::now();
	int64_t kNumberOfEpochs = 2;
	for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
		int i = 0;
		auto running_loss = 0.;
		for (torch::data::Example<>& batch : *data_loader) {
			auto inputs = batch.data;
			auto labels = batch.target;
			// Zero the optimizer's gradients
			optimizer.zero_grad();
			// Feed the data to the network
			auto outputs = net1.forward(inputs);
			// Compute the loss via cross entropy
			auto loss = criterion(outputs, labels);
			// Back-propagate so the weights can be adjusted further
			loss.backward();
			optimizer.step();
			running_loss += loss.item().toFloat();
			if ((i + 1) % 3000 == 0)
			{
				printf("[%lld, %5d] loss: %.3f\n", epoch, i + 1, running_loss / 3000);
				running_loss = 0.;
			}
			i++;
		}
	}
	printf("Finish training!\n");
	torch::serialize::OutputArchive archive;
	net1.save(archive);
	archive.save_to("I:\\mnist.pt");
	printf("Save the training result to I:\\mnist.pt.\n");
	tm_end = std::chrono::system_clock::now();
	printf("It takes %lld msec to finish training handwriting!\n",
		std::chrono::duration_cast<std::chrono::milliseconds>(tm_end - tm_start).count());
}
Output
The Debug configuration is far too slow here; switch to Release so compiler optimizations are enabled. Training still takes some time, since there are 60,000 training images; it took a few minutes on my machine:
The result is decent: after two epochs of training, the loss has become quite small.
Code Walkthrough
The MNIST database format is described at http://yann.lecun.com/exdb/mnist/
train-images-idx3-ubyte: training set images
train-labels-idx1-ubyte: training set labels
t10k-images-idx3-ubyte: test set images
t10k-labels-idx1-ubyte: test set labels
The call
torch::data::datasets::MNIST("I:\\MNIST\\")
loads train-images-idx3-ubyte and train-labels-idx1-ubyte. The train-images file has the following structure:
TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
[offset]  [type]           [value]            [description]
0000      32 bit integer   0x00000803 (2051)  magic number
0004      32 bit integer   60000              number of images
0008      32 bit integer   28                 number of rows
0012      32 bit integer   28                 number of columns
0016      unsigned byte    ??                 pixel
0017      unsigned byte    ??                 pixel
........
xxxx      unsigned byte    ??                 pixel
Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).
The loader also normalizes each pixel from [0, 255] down to [0.0, 1.0]. Then the statement:
.map(torch::data::transforms::Normalize<>(0.5, 0.5))
maps each pixel further to [-1.0, 1.0], which is easier to work with. As a formula:
$$\bar I = \overline{image} / 255.0 \qquad \bar D = (\bar I - 0.5) / 0.5$$
Next, the following transform collates the examples in each batch into a single rank-4 tensor (batch_size x 1 x 28 x 28); stacking the whole training set this way would give 60000 x 1 x 28 x 28:
.map(torch::data::transforms::Stack<>())
For a multi-class network like this one, cross-entropy is the usual loss function:
auto criterion = torch::nn::CrossEntropyLoss();
The following code performs the training: the network's output and the ground-truth labels are fed to the cross-entropy loss, the loss is differentiated and back-propagated through the network, and the Stochastic Gradient Descent (SGD) optimizer then adjusts the parameters; this is how the network learns:
// Zero the optimizer's gradients
optimizer.zero_grad();
// Feed the data to the network
auto outputs = net1.forward(inputs);
// Compute the loss via cross entropy
auto loss = criterion(outputs, labels);
// Back-propagate so the weights can be adjusted further
loss.backward();
// Let the optimizer adjust the network parameters
optimizer.step();
Finally, once training completes, save the result so it can be loaded again later.
torch::serialize::OutputArchive archive;
net1.save(archive);
archive.save_to("I:\\mnist.pt");
Loading the Trained Model and Testing
The training result obtained above can be loaded with the following code:
{
	tm_start = std::chrono::system_clock::now();
	LeNet5 net1(2);
	torch::serialize::InputArchive archive;
	archive.load_from("I:\\mnist.pt");
	net1.load(archive);
	auto dataset = torch::data::datasets::MNIST("I:\\MNIST\\", torch::data::datasets::MNIST::Mode::kTest)
		.map(torch::data::transforms::Normalize<>(0.5, 0.5))
		.map(torch::data::transforms::Stack<>());
	auto data_loader = torch::data::make_data_loader(std::move(dataset));
	int total_test_items = 0, passed_test_items = 0;
	for (torch::data::Example<>& batch : *data_loader)
	{
		// Run the test data through the trained network
		auto outputs = net1.forward(batch.data);
		// Get the prediction, 0 ~ 9 (the index of the largest output)
		auto predicted = torch::max(outputs, 1);
		// Get the label, 0 ~ 9
		auto labels = batch.target;
		// Compare the prediction with the ground truth and update the statistics
		if (labels[0].item<int>() == std::get<1>(predicted).item<int>())
			passed_test_items++;
		total_test_items++;
		//printf("label: %d.\n", labels[0].item<int>());
		//printf("predicted label: %d.\n", std::get<1>(predicted).item<int>());
		//std::cout << std::get<1>(predicted) << '\n';
		//break;
	}
	tm_end = std::chrono::system_clock::now();
	printf("Total test items: %d, passed test items: %d, pass rate: %.3f%%, cost %lld msec.\n",
		total_test_items, passed_test_items, passed_test_items * 100.f / total_test_items,
		std::chrono::duration_cast<std::chrono::milliseconds>(tm_end - tm_start).count());
}
Output
10,000 test images in about 8 seconds, roughly 0.8 ms per image on average. Pretty fast!