有时会用到将一组图像存放成MNIST中那样的数据格式,以便用于网络的训练和测试,如MNIST中的测试集标签t10k-labels.idx1-ubyte和测试集图像t10k-images.idx3-ubyte,各包含了10000个样本,这里以此两个测试集为例详细说明下实现过程:
在http://yann.lecun.com/exdb/mnist/ 中对MNIST的数据存放格式进行了介绍,存储的数据都以大多数非英特尔处理器使用的MSB优先(高端)格式存储,英特尔处理器和其他低端机器的用户必须翻转标头的字节(All the integers in the files are stored in the MSB first(high endian) format used by most non-Intel processors. Users of Intel processors and other low-endian machines must flip the bytes of the header.)。
t10k-labels.idx1-ubyte(训练集标签train-labels.idx1-ubyte与此存放格式完全相同):第1至第4个字节存放magic number(MSB first);第5至第8个字节存放标签数即10000;从第9个字节开始,每个字节存放一个标签值(label value),标签值的范围为0到9。
此处的magic number(MSB first)是一个四个字节的整数,用于标识IDX文件格式;第1,第2个字节总是0;第3个字节值表示数据的类型,如0x08表示unsigned byte;0x09表示signed byte;0x0B表示short(2 bytes);0x0C表示int(4 bytes);0x0D表示float(4 bytes);0x0E表示double(8 bytes);因为t10k-labels.idx1-ubyte中标签值范围为0到9,因此这里第3字节值为0x08;第4个字节表示向量/矩阵的维数,1表示向量,2表示矩阵等;这里的标签为一维向量,因此第4字节为0x01。t10k-labels.idx1-ubyte中的前8个字节是两个四字节整数,依次为magic number和标签数。
打开t10k-labels.idx1-ubyte二进制文件,前8个字节数据是:00 00 08 01 00 00 27 10(即magic number为0x00000801=2049,标签数为0x00002710=10000)。这里需要注意的是,magic number是一个四字节int,在读或写时一次性读写4个字节;文件中按高字节在前、低字节在后的顺序存储(MSB first),而英特尔等小端机器在内存中是低字节在前、高字节在后,两者顺序相反,因此在读或写这些四字节整数时,需要做个转换,即高字节变低字节,低字节变高字节,实现见ReverseInt函数。
t10k-images.idx3-ubyte(训练集图像train-images.idx3-ubyte与此存放格式完全相同):第1至第4个字节存放magic number(MSB first);第5至第8个字节存放图像数即10000;第9至第12个字节存放每个图像的行数即高,这里为28;第13至第16个字节存放每个图像的列数即宽,这里为28;从第17个字节开始,每个字节存放一个像素值,像素值的范围为0到255,0表示背景,255表示前景,像素按行排列;每28*28个字节大小存放一幅图像数据。
此处的magic number(MSB first)是一个四个字节的整数,用于标识IDX文件格式;第1,第2个字节总是0;第3个字节值表示数据的类型,如0x08表示unsigned byte;0x09表示signed byte;0x0B表示short(2 bytes);0x0C表示int(4 bytes);0x0D表示float(4 bytes);0x0E表示double(8 bytes);因为t10k-images.idx3-ubyte中图像像素值范围为0到255,因此这里第3字节值为0x08;第4个字节表示向量/矩阵的维数,1表示向量,2表示矩阵等;这里的图像数据可看做三维即图像数*height*width,因此第4字节为0x03。t10k-images.idx3-ubyte中的前16个字节是四个四字节整数,依次为magic number、图像数、行数和列数。打开t10k-images.idx3-ubyte二进制文件,前16个字节数据是:00 00 08 03 00 00 27 10 00 00 00 1c 00 00 00 1c。
测试代码如下:
#include "funset.hpp"
#include <cstddef>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include <opencv2/opencv.hpp>
// MNIST /
namespace {
// Reverses the byte order of a 32-bit integer: byte 0 <-> byte 3, byte 1 <-> byte 2.
//
// IDX/MNIST headers store their integers MSB first; the readers and writers in
// this file pass every header field through this function.
//
// The swap is performed on an unsigned copy of the value. Shifting a byte
// >= 0x80 left by 24 bits in a signed int moves a 1 into the sign bit, which is
// undefined behavior before C++20; unsigned shifts are always well defined.
//
// NOTE(review): assumes `int` is 32 bits wide, as the rest of this file does.
//
// @param i  value whose four bytes are to be reversed
// @return   the value with its bytes in the opposite order
int ReverseInt(int i)
{
    const unsigned int u = static_cast<unsigned int>(i);
    const unsigned int swapped = ((u & 0x000000FFu) << 24) |
                                 ((u & 0x0000FF00u) << 8) |
                                 ((u & 0x00FF0000u) >> 8) |
                                 ((u >> 24) & 0x000000FFu);
    return static_cast<int>(swapped);
}
// Reads an IDX3 image file (e.g. t10k-images.idx3-ubyte) and appends every
// image it contains to `vec` as an n_rows x n_cols CV_8UC1 matrix.
//
// File layout: four 32-bit header integers (magic number, image count, rows,
// cols) followed by the pixels, one unsigned byte each, row by row, image after
// image. Each header integer is passed through ReverseInt() after reading.
//
// Failure handling: the function returns void, so errors are reported on stderr
// and reading stops. Only images that were read completely are appended; the
// caller can detect a problem by comparing vec.size() with the expected count.
//
// @param filename  path of the IDX3 file to read
// @param vec       output; decoded images are appended in file order
void read_Mnist(std::string filename, std::vector<cv::Mat> &vec)
{
    std::ifstream file(filename, std::ios::binary);
    if (!file.is_open()) {
        fprintf(stderr, "Error: open file fail: %s\n", filename.c_str());
        return;
    }

    // Header fields in file order: magic number, image count, rows, cols.
    int header[4] = {0, 0, 0, 0};
    for (int &field : header) {
        file.read(reinterpret_cast<char*>(&field), sizeof(field));
        field = ReverseInt(field);
    }
    const int number_of_images = header[1];
    const int n_rows = header[2];
    const int n_cols = header[3];

    // A short or corrupt header must not drive the allocation/read loops below.
    if (!file || number_of_images < 0 || n_rows < 0 || n_cols < 0) {
        fprintf(stderr, "Error: invalid IDX image header: %s\n", filename.c_str());
        return;
    }

    for (int i = 0; i < number_of_images; ++i) {
        cv::Mat tp(n_rows, n_cols, CV_8UC1);
        // One read per row instead of one per pixel; going through ptr() keeps
        // this correct regardless of the matrix row stride.
        for (int r = 0; r < n_rows; ++r) {
            file.read(reinterpret_cast<char*>(tp.ptr<uchar>(r)), n_cols);
        }
        if (!file) {
            // Truncated file: drop the partial image rather than store garbage.
            fprintf(stderr, "Error: unexpected end of file at image %d: %s\n", i, filename.c_str());
            return;
        }
        vec.push_back(tp);
    }
}
// Reads an IDX1 label file (e.g. t10k-labels.idx1-ubyte) into `vec`.
//
// File layout: two 32-bit header integers (magic number, label count) followed
// by one unsigned byte per label. Each header integer is passed through
// ReverseInt() after reading.
//
// Label i is stored in vec[i]. If `vec` is shorter than the number of labels in
// the file it is grown as labels are actually read, instead of being written
// out of bounds; elements beyond the last label read are left untouched.
//
// Failure handling: the function returns void, so errors are reported on stderr
// and reading stops at the first failed read.
//
// @param filename  path of the IDX1 file to read
// @param vec       output; receives one int per label, in file order
void read_Mnist_Label(std::string filename, std::vector<int> &vec)
{
    std::ifstream file(filename, std::ios::binary);
    if (!file.is_open()) {
        fprintf(stderr, "Error: open file fail: %s\n", filename.c_str());
        return;
    }

    int magic_number = 0;
    int number_of_labels = 0;
    file.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
    magic_number = ReverseInt(magic_number);
    file.read(reinterpret_cast<char*>(&number_of_labels), sizeof(number_of_labels));
    number_of_labels = ReverseInt(number_of_labels);
    (void)magic_number;  // read to advance past it; its value is not checked here

    if (!file || number_of_labels < 0) {
        fprintf(stderr, "Error: invalid IDX label header: %s\n", filename.c_str());
        return;
    }

    for (int i = 0; i < number_of_labels; ++i) {
        unsigned char label = 0;
        file.read(reinterpret_cast<char*>(&label), sizeof(label));
        if (!file) {
            fprintf(stderr, "Error: unexpected end of file at label %d: %s\n", i, filename.c_str());
            return;
        }
        const std::size_t idx = static_cast<std::size_t>(i);
        if (idx < vec.size()) {
            vec[idx] = static_cast<int>(label);
        } else {
            // Grow only for data that really exists in the file, so a bogus
            // header count cannot trigger a huge allocation.
            vec.push_back(static_cast<int>(label));
        }
    }
}
// Builds a name of the form "<number>_<counter>" for the next image of class
// `number`.
//
// For number in [0, 9], arr[number] is incremented and the new counter value is
// left-padded with zeros to five characters when it is below 10000. For any
// other number no counter is touched and the part after '_' is empty.
//
// @param number  class index used as the name prefix
// @param arr     per-class counters; must have at least 10 elements
// @return        the generated name, e.g. "3_00001"
std::string GetImageName(int number, int arr[])
{
    const std::string prefix = std::to_string(number) + "_";
    if (number < 0 || number >= 10) {
        return prefix;
    }

    const int count = ++arr[number];

    // One leading zero for every decimal threshold (10000, 1000, 100, 10) the
    // counter stays below, stopping at the first one it reaches.
    int zeros = 0;
    for (int limit = 10000; limit >= 10 && count < limit; limit /= 10) {
        ++zeros;
    }

    return prefix + std::string(zeros, '0') + std::to_string(count);
}
// Writes the first `image_number` images of `image_data` to `file_name` in the
// IDX3 layout used by MNIST: four 32-bit header integers (magic number, image
// count, rows, cols), each passed through ReverseInt() before writing, followed
// by the pixels, one unsigned byte each, row by row.
//
// Every image to be written must be an image_rows x image_cols CV_8UC1 matrix.
// This is checked before the output file is created, so a mismatching image can
// neither cause a read past its buffer nor leave a half-written file behind.
//
// @param file_name     path of the file to create/overwrite
// @param image_data    source images
// @param magic_number  IDX magic number to store (2051 for ubyte, 3 dimensions)
// @param image_number  number of images to write; must be <= image_data.size()
// @param image_rows    height of every image
// @param image_cols    width of every image
// @return 0 on success, -1 on invalid arguments or I/O failure
int write_images_to_file(const std::string& file_name, const std::vector<cv::Mat>& image_data,
    int magic_number, int image_number, int image_rows, int image_cols)
{
    if (image_number < 0 || image_rows < 0 || image_cols < 0) {
        fprintf(stderr, "Error: negative image_number/image_rows/image_cols: %d, %d, %d\n",
            image_number, image_rows, image_cols);
        return -1;
    }
    if (static_cast<std::size_t>(image_number) > image_data.size()) {
        // %zu is the correct conversion for size_t; %d is undefined for it.
        fprintf(stderr, "Error: image_number > image_data.size(): "
            "image_number: %d, image_data.size: %zu\n", image_number, image_data.size());
        return -1;
    }
    for (int i = 0; i < image_number; ++i) {
        const cv::Mat& img = image_data[i];
        if (img.rows != image_rows || img.cols != image_cols || img.type() != CV_8UC1) {
            fprintf(stderr, "Error: image %d is not a %dx%d CV_8UC1 matrix\n",
                i, image_rows, image_cols);
            return -1;
        }
    }

    std::ofstream file(file_name, std::ios::binary);
    if (!file.is_open()) {
        fprintf(stderr, "Error: open file fail: %s\n", file_name.c_str());
        return -1;
    }

    const int header[4] = {magic_number, image_number, image_rows, image_cols};
    for (const int field : header) {
        const int tmp = ReverseInt(field);
        file.write(reinterpret_cast<const char*>(&tmp), sizeof(tmp));
    }

    for (int i = 0; i < image_number; ++i) {
        const cv::Mat& img = image_data[i];
        // Row-wise writes stay correct even if the matrix is not continuous
        // in memory (e.g. a region of interest of a larger image).
        for (int r = 0; r < image_rows; ++r) {
            file.write(reinterpret_cast<const char*>(img.ptr<uchar>(r)), image_cols);
        }
    }

    file.close();
    if (!file) {
        fprintf(stderr, "Error: write file fail: %s\n", file_name.c_str());
        return -1;
    }
    return 0;
}
// Writes the first `label_number` labels of `label_data` to `file_name` in the
// IDX1 layout used by MNIST: two 32-bit header integers (magic number, label
// count), each passed through ReverseInt() before writing, followed by one
// unsigned byte per label.
//
// Because each label is stored in a single unsigned byte, values outside
// [0, 255] are rejected rather than silently truncated.
//
// @param file_name     path of the file to create/overwrite
// @param label_data    source labels
// @param magic_number  IDX magic number to store (2049 for ubyte, 1 dimension)
// @param label_number  number of labels to write; must be <= label_data.size()
// @return 0 on success, -1 on invalid arguments or I/O failure
int write_labels_to_file(const std::string& file_name, const std::vector<int>& label_data,
    int magic_number, int label_number)
{
    if (label_number < 0) {
        fprintf(stderr, "Error: negative label_number: %d\n", label_number);
        return -1;
    }
    if (static_cast<std::size_t>(label_number) > label_data.size()) {
        // %zu is the correct conversion for size_t; %d is undefined for it.
        fprintf(stderr, "Error: label_number > label_data.size(): "
            "label_number: %d, label_data.size: %zu\n", label_number, label_data.size());
        return -1;
    }

    // Convert before touching the file system so bad input leaves no file behind.
    std::vector<unsigned char> labels(static_cast<std::size_t>(label_number));
    for (int i = 0; i < label_number; ++i) {
        const int value = label_data[i];
        if (value < 0 || value > 255) {
            fprintf(stderr, "Error: label %d has value %d, which does not fit in one unsigned byte\n",
                i, value);
            return -1;
        }
        labels[i] = static_cast<unsigned char>(value);
    }

    std::ofstream file(file_name, std::ios::binary);
    if (!file.is_open()) {
        fprintf(stderr, "Error: open file fail: %s\n", file_name.c_str());
        return -1;
    }

    int tmp = ReverseInt(magic_number);
    file.write(reinterpret_cast<const char*>(&tmp), sizeof(tmp));
    tmp = ReverseInt(label_number);
    file.write(reinterpret_cast<const char*>(&tmp), sizeof(tmp));
    file.write(reinterpret_cast<const char*>(labels.data()), labels.size());

    file.close();
    if (!file) {
        fprintf(stderr, "Error: write file fail: %s\n", file_name.c_str());
        return -1;
    }
    return 0;
}
} // namespace //mnist
int ImageToMNIST()
{
// read images
#ifdef _MSC_VER
std::string filename_test_images = "E:/GitCode/NN_Test/data/database/MNIST/t10k-images.idx3-ubyte";
#else
std::string filename_test_images = "data/database/MNIST/t10k-images.idx3-ubyte";
#endif
const int number_of_test_images = 10000;
std::vector<cv::Mat> vec_test_images;
read_Mnist(filename_test_images, vec_test_images);
if (vec_test_images.size() != number_of_test_images) {
fprintf(stderr, "Error: fail to parse t10k-images.idx3-ubyte file: %d\n", vec_test_images.size());
return -1;
}
// read labels
#ifdef _MSC_VER
std::string filename_test_labels = "E:/GitCode/NN_Test/data/database/MNIST/t10k-labels.idx1-ubyte";
#else
std::string filename_test_labels = "data/database/MNIST/t10k-labels.idx1-ubyte";
#endif
std::vector<int> vec_test_labels(number_of_test_images);
read_Mnist_Label(filename_test_labels, vec_test_labels);
// write images
const int image_magic_number = 2051; // 0x00000803
const int image_number = 10000;
const int image_rows = 28;
const int image_cols = 28;
#ifdef _MSC_VER
const std::string images_save_file_name = "E:/GitCode/NN_Test/data/new_t10k-images.idx3-ubyte";
#else
const std::string images_save_file_name = "data/new_t10k-images.idx3-ubyte";
#endif
if (write_images_to_file(images_save_file_name, vec_test_images, image_magic_number,
image_number, image_rows, image_cols) != 0) {
fprintf(stderr, "Error: write images to file fail\n");
return -1;
}
// write labels
const int label_magic_number = 2049; // 0x00000801
const int label_number = 10000;
#ifdef _MSC_VER
const std::string labels_save_file_name = "E:/GitCode/NN_Test/data/new_t10k-labels.idx1-ubyte";
#else
const std::string labels_save_file_name = "data/new_t10k-labels.idx1-ubyte";
#endif
if (write_labels_to_file(labels_save_file_name, vec_test_labels, label_magic_number, label_number) != 0) {
fprintf(stderr, "Error: write labels to file fail\n");
return -1;
}
return 0;
}
新生成的两个数据文件为new_t10k-labels.idx1-ubyte和new_t10k-images.idx3-ubyte,通过md5可知,新生成的文件与原始文件完全相同,结果如下: