Storing an Image Set in the MNIST Dataset Format

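Sometimes you need to store a set of images in the same data format as MNIST so that they can be used to train and test a network. For example, the MNIST test set consists of the label file t10k-labels.idx1-ubyte and the image file t10k-images.idx3-ubyte, each containing 10000 samples. Using these two test-set files as an example, the implementation is described in detail below.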

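http://yann.lecun.com/exdb/mnist/ describes how the MNIST data is stored. All integers in the files are stored in the MSB-first (big-endian) format used by most non-Intel processors, so users of Intel processors and other little-endian machines must flip the bytes of the header (All the integers in the files are stored in the MSB first(high endian) format used by most non-Intel processors. Users of Intel processors and other low-endian machines must flip the bytes of the header.).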

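t10k-labels.idx1-ubyte (the training-set label file train-labels.idx1-ubyte uses exactly the same layout): bytes 1 to 4 hold the magic number (MSB first); bytes 5 to 8 hold the number of labels, here 10000; from byte 9 onward, each byte holds one label value, in the range 0 to 9.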

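The magic number (MSB first) here is a four-byte integer that identifies the IDX file format. The 1st and 2nd bytes are always 0. The 3rd byte gives the data type: 0x08 means unsigned byte, 0x09 signed byte, 0x0B short (2 bytes), 0x0C int (4 bytes), 0x0D float (4 bytes), and 0x0E double (8 bytes). The labels in t10k-labels.idx1-ubyte range from 0 to 9 and are stored as unsigned bytes, so the 3rd byte is 0x08. The 4th byte gives the number of dimensions of the vector/matrix: 1 for a vector, 2 for a matrix, and so on. The labels form a one-dimensional vector, so the 4th byte is 0x01. The first 8 bytes of t10k-labels.idx1-ubyte are therefore two 32-bit integers: the magic number and the number of labels.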
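The field layout just described can be made concrete with a small decoder. This is an illustrative sketch only: `IdxMagic`, `DecodeIdxMagic` and `IdxElementSize` are hypothetical names that do not appear in the original code, and the input is assumed to be the magic number already converted to host byte order (for example with ReverseInt).

```cpp
#include <cstdint>

// Hypothetical helper type (not part of the original code): the fields of
// an IDX magic number, already converted to host byte order.
struct IdxMagic {
	bool valid;              // true if the 1st and 2nd bytes are both 0
	std::uint8_t type_code;  // 3rd byte: 0x08 unsigned byte, 0x09 signed byte,
	                         // 0x0B short, 0x0C int, 0x0D float, 0x0E double
	std::uint8_t ndims;      // 4th byte: 1 = vector, 2 = matrix, 3 = 3-D array
};

// Split the 32-bit magic number into its four byte-sized fields.
inline IdxMagic DecodeIdxMagic(std::uint32_t magic)
{
	IdxMagic m;
	m.valid = (magic >> 16) == 0;                                  // bytes 1-2
	m.type_code = static_cast<std::uint8_t>((magic >> 8) & 0xFF);  // byte 3
	m.ndims = static_cast<std::uint8_t>(magic & 0xFF);             // byte 4
	return m;
}

// Hypothetical helper: size in bytes of one element for a type code,
// or 0 if the code is not one of the six listed above.
inline int IdxElementSize(std::uint8_t type_code)
{
	switch (type_code) {
	case 0x08: case 0x09: return 1;  // unsigned / signed byte
	case 0x0B: return 2;             // short
	case 0x0C: case 0x0D: return 4;  // int / float
	case 0x0E: return 8;             // double
	default: return 0;
	}
}
```

For example, DecodeIdxMagic(2049), i.e. 0x00000801, yields type code 0x08 with 1 dimension (the label file), while DecodeIdxMagic(2051), i.e. 0x00000803, yields type code 0x08 with 3 dimensions (the image file).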

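Opening t10k-labels.idx1-ubyte as a binary file, its first 8 bytes are: 00 00 08 01 00 00 27 10, i.e. the magic number 0x00000801 (2049) followed by the label count 0x00002710 (10000). Note that when such a four-byte integer is read into an int in one go on a little-endian machine (such as x86), the first byte in the file becomes the least significant byte, which is the reverse of the MSB-first order in which it is stored. The byte order must therefore be reversed after reading and before writing every header integer (the magic number as well as the counts and sizes); this is what the ReverseInt function does.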
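As an alternative to reading an int and then swapping it with ReverseInt, a header field can be assembled from its four bytes explicitly, which gives the same result on little- and big-endian hosts. A minimal sketch; `BigEndianToU32` and `ReadBigEndianU32` are hypothetical helpers, not part of the original code:

```cpp
#include <cstdint>
#include <istream>

// Hypothetical helper: combine four bytes stored MSB first (big-endian)
// into a 32-bit unsigned integer. The shifts make the result independent
// of the host's byte order, so no separate byte swap is needed.
inline std::uint32_t BigEndianToU32(const unsigned char b[4])
{
	return (static_cast<std::uint32_t>(b[0]) << 24) |
	       (static_cast<std::uint32_t>(b[1]) << 16) |
	       (static_cast<std::uint32_t>(b[2]) << 8) |
	       static_cast<std::uint32_t>(b[3]);
}

// Hypothetical helper: read one big-endian 32-bit header field from a
// binary stream. Returns false if fewer than 4 bytes could be read.
inline bool ReadBigEndianU32(std::istream& in, std::uint32_t& value)
{
	unsigned char buf[4];
	if (!in.read(reinterpret_cast<char*>(buf), 4)) {
		return false;
	}
	value = BigEndianToU32(buf);
	return true;
}
```

Applied to the eight header bytes shown above, BigEndianToU32 returns 2049 for the first four bytes and 10000 for the last four.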

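t10k-images.idx3-ubyte (the training-set image file train-images.idx3-ubyte uses exactly the same layout): bytes 1 to 4 hold the magic number (MSB first); bytes 5 to 8 hold the number of images, here 10000; bytes 9 to 12 hold the number of rows (height) of each image, here 28; bytes 13 to 16 hold the number of columns (width) of each image, here 28. From byte 17 onward, each byte holds one pixel value in the range 0 to 255, where 0 is background and 255 is foreground. Pixels are stored row by row, and each image occupies 28*28 = 784 bytes.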

The magic number (MSB first) here has the same structure: the 1st and 2nd bytes are always 0, and the 3rd byte is the data type code listed above. The pixel values in t10k-images.idx3-ubyte range from 0 to 255 and are stored as unsigned bytes, so the 3rd byte is 0x08. The images form a three-dimensional array of size number of images × height × width (10000 × 28 × 28), so the 4th byte is 0x03. The first 16 bytes of t10k-images.idx3-ubyte are therefore four 32-bit integers: the magic number, the number of images, the number of rows, and the number of columns. Opening t10k-images.idx3-ubyte as a binary file, its first 16 bytes are: 00 00 08 03 00 00 27 10 00 00 00 1c 00 00 00 1c.
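Because the header fully determines the layout, the expected file size can be checked before parsing: 4 bytes for the magic number, 4 bytes for each dimension size, then one element per entry of the data array. A sketch, assuming the element size is already known from the type code; `IdxFileSize` is a hypothetical name:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: expected size in bytes of an IDX file whose header
// lists the dimension sizes `dims` and whose elements are `elem_size` bytes.
inline std::uint64_t IdxFileSize(const std::vector<std::uint32_t>& dims, std::uint32_t elem_size)
{
	// magic number + one 4-byte size field per dimension
	std::uint64_t header = 4 + 4 * static_cast<std::uint64_t>(dims.size());
	// total number of elements = product of all dimension sizes
	std::uint64_t count = 1;
	for (std::uint32_t d : dims) {
		count *= d;
	}
	return header + count * elem_size;
}
```

For t10k-images.idx3-ubyte, IdxFileSize({10000, 28, 28}, 1) is 16 + 10000 × 784 = 7840016 bytes; for t10k-labels.idx1-ubyte, IdxFileSize({10000}, 1) is 8 + 10000 = 10008 bytes.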

The test code is as follows:

#include "funset.hpp"
#include <cstdio>
#include <iostream>
#include <fstream>
#include <string>
#include <vector>
#include <memory>
#include <opencv2/opencv.hpp>

// MNIST read/write helpers
namespace {
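// Swap the byte order of a 32-bit integer (big-endian <-> little-endian).
// Done on unsigned values so that shifting a byte >= 0x80 into the top
// position does not overflow a signed int.
int ReverseInt(int i)
{
	unsigned int u = static_cast<unsigned int>(i);
	u = ((u & 0x000000FFu) << 24) | ((u & 0x0000FF00u) << 8) |
		((u & 0x00FF0000u) >> 8) | ((u & 0xFF000000u) >> 24);
	return static_cast<int>(u);
}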

void read_Mnist(std::string filename, std::vector<cv::Mat> &vec)
{
	std::ifstream file(filename, std::ios::binary);
	if (file.is_open()) {
		int magic_number = 0;
		int number_of_images = 0;
		int n_rows = 0;
		int n_cols = 0;
		file.read((char*)&magic_number, sizeof(magic_number));
		magic_number = ReverseInt(magic_number);
		file.read((char*)&number_of_images, sizeof(number_of_images));
		number_of_images = ReverseInt(number_of_images);
		file.read((char*)&n_rows, sizeof(n_rows));
		n_rows = ReverseInt(n_rows);
		file.read((char*)&n_cols, sizeof(n_cols));
		n_cols = ReverseInt(n_cols);

		for (int i = 0; i < number_of_images; ++i) {
			cv::Mat tp = cv::Mat::zeros(n_rows, n_cols, CV_8UC1);
			for (int r = 0; r < n_rows; ++r) {
				for (int c = 0; c < n_cols; ++c) {
					unsigned char temp = 0;
					file.read((char*)&temp, sizeof(temp));
					tp.at<uchar>(r, c) = (int)temp;
				}
			}
			vec.push_back(tp);
		}

		file.close();
	}
}

void read_Mnist_Label(std::string filename, std::vector<int> &vec)
{
	std::ifstream file(filename, std::ios::binary);
	if (file.is_open()) {
		int magic_number = 0;
		int number_of_labels = 0;
		file.read((char*)&magic_number, sizeof(magic_number));
		magic_number = ReverseInt(magic_number);
		file.read((char*)&number_of_labels, sizeof(number_of_labels));
		number_of_labels = ReverseInt(number_of_labels);

		// size the output to the label count stored in the header,
		// so vec[i] never writes past the end of the vector
		vec.resize(number_of_labels);
		for (int i = 0; i < number_of_labels; ++i) {
			unsigned char temp = 0;
			file.read((char*)&temp, sizeof(temp));
			vec[i] = (int)temp;
		}

		file.close();
	}
}

std::string GetImageName(int number, int arr[])
{
	std::string str1, str2;

	for (int i = 0; i < 10; i++) {
		if (number == i) {
			arr[i]++;
			str1 = std::to_string(arr[i]);

			if (arr[i] < 10) {
				str1 = "0000" + str1;
			} else if (arr[i] < 100) {
				str1 = "000" + str1;
			} else if (arr[i] < 1000) {
				str1 = "00" + str1;
			} else if (arr[i] < 10000) {
				str1 = "0" + str1;
			}

			break;
		}
	}

	str2 = std::to_string(number) + "_" + str1;

	return str2;
}

int write_images_to_file(const std::string& file_name, const std::vector<cv::Mat>& image_data,
	int magic_number, int image_number, int image_rows, int image_cols)
{
	if (image_number < 0 || static_cast<size_t>(image_number) > image_data.size()) {
		// %zu is the correct format specifier for size_t
		fprintf(stderr, "Error: image_number > image_data.size(): image_number: %d, image_data.size: %zu\n",
			image_number, image_data.size());
		return -1;
	}

	std::ofstream file(file_name, std::ios::binary);
	if (!file.is_open()) {
		fprintf(stderr, "Error: open file fail: %s\n", file_name.c_str());
		return -1;
	}

	int tmp = ReverseInt(magic_number);
	file.write((char*)&tmp, sizeof(int));
	tmp = ReverseInt(image_number);
	file.write((char*)&tmp, sizeof(int));
	tmp = ReverseInt(image_rows);
	file.write((char*)&tmp, sizeof(int));
	tmp = ReverseInt(image_cols);
	file.write((char*)&tmp, sizeof(int));

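	const size_t size = static_cast<size_t>(image_rows) * image_cols;
	for (int i = 0; i < image_number; ++i) {
		// writing img.data directly requires a continuous CV_8UC1 Mat of the expected size
		const cv::Mat& img = image_data[i];
		if (img.type() != CV_8UC1 || img.rows != image_rows || img.cols != image_cols || !img.isContinuous()) {
			fprintf(stderr, "Error: image %d is not a continuous %dx%d CV_8UC1 Mat\n", i, image_rows, image_cols);
			return -1;
		}
		file.write((const char*)img.data, sizeof(unsigned char) * size);
	}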

	file.close();
	return 0;
}

int write_labels_to_file(const std::string& file_name, const std::vector<int>& label_data,
	int magic_number, int label_number)
{
	if (label_number < 0 || static_cast<size_t>(label_number) > label_data.size()) {
		// %zu is the correct format specifier for size_t
		fprintf(stderr, "Error: label_number > label_data.size(): label_number: %d, label_data.size: %zu\n",
			label_number, label_data.size());
		return -1;
	}

	std::ofstream file(file_name, std::ios::binary);
	if (!file.is_open()) {
		fprintf(stderr, "Error: open file fail: %s\n", file_name.c_str());
		return -1;
	}

	int tmp = ReverseInt(magic_number);
	file.write((char*)&tmp, sizeof(int));
	tmp = ReverseInt(label_number);
	file.write((char*)&tmp, sizeof(int));

	std::unique_ptr<unsigned char[]> labels(new unsigned char[label_number]);
	for (int i = 0; i < label_number; ++i) {
		labels[i] = static_cast<unsigned char>(label_data[i]);
	}
	file.write((char*)labels.get(), sizeof(unsigned char) * label_number);

	file.close();
	return 0;
}
} // namespace //mnist

int ImageToMNIST()
{
	// read images
#ifdef _MSC_VER
	std::string filename_test_images = "E:/GitCode/NN_Test/data/database/MNIST/t10k-images.idx3-ubyte";
#else
	std::string filename_test_images = "data/database/MNIST/t10k-images.idx3-ubyte";
#endif
	const int number_of_test_images = 10000;
	std::vector<cv::Mat> vec_test_images;

	read_Mnist(filename_test_images, vec_test_images);
	if (vec_test_images.size() != number_of_test_images) {
		fprintf(stderr, "Error: fail to parse t10k-images.idx3-ubyte file: %zu\n", vec_test_images.size());
		return -1;
	}

	// read labels
#ifdef _MSC_VER
	std::string filename_test_labels = "E:/GitCode/NN_Test/data/database/MNIST/t10k-labels.idx1-ubyte";
#else
	std::string filename_test_labels = "data/database/MNIST/t10k-labels.idx1-ubyte";
#endif
	std::vector<int> vec_test_labels(number_of_test_images);

	read_Mnist_Label(filename_test_labels, vec_test_labels);

	// write images
	const int image_magic_number = 2051; // 0x00000803
	const int image_number = 10000;
	const int image_rows = 28;
	const int image_cols = 28;
#ifdef _MSC_VER
	const std::string images_save_file_name = "E:/GitCode/NN_Test/data/new_t10k-images.idx3-ubyte";
#else
	const std::string images_save_file_name = "data/new_t10k-images.idx3-ubyte";
#endif

	if (write_images_to_file(images_save_file_name, vec_test_images, image_magic_number,
		image_number, image_rows, image_cols) != 0) {
		fprintf(stderr, "Error: write images to file fail\n");
		return -1;
	}

	// write labels
	const int label_magic_number = 2049; // 0x00000801
	const int label_number = 10000;
#ifdef _MSC_VER
	const std::string labels_save_file_name = "E:/GitCode/NN_Test/data/new_t10k-labels.idx1-ubyte";
#else
	const std::string labels_save_file_name = "data/new_t10k-labels.idx1-ubyte";
#endif

	if (write_labels_to_file(labels_save_file_name, vec_test_labels, label_magic_number, label_number) != 0) {
		fprintf(stderr, "Error: write labels to file fail\n");
		return -1;
	}

	return 0;
}

The two newly generated data files are new_t10k-labels.idx1-ubyte and new_t10k-images.idx3-ubyte. Comparing MD5 checksums shows that they are identical to the original files.
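If no MD5 tool is at hand, a byte-for-byte comparison in C++ answers the same question of whether the regenerated files match the originals. This is a swapped-in technique rather than the MD5 check used above, and `FilesIdentical` is a hypothetical helper:

```cpp
#include <algorithm>
#include <fstream>
#include <iterator>
#include <string>

// Hypothetical helper: returns true if both files can be opened and have
// exactly the same contents (same length, same bytes).
inline bool FilesIdentical(const std::string& path_a, const std::string& path_b)
{
	std::ifstream a(path_a, std::ios::binary);
	std::ifstream b(path_b, std::ios::binary);
	if (!a || !b) {
		return false;
	}
	// istreambuf_iterator reads raw bytes without skipping whitespace;
	// the four-iterator std::equal also returns false if the lengths differ
	std::istreambuf_iterator<char> ia(a), ib(b), end;
	return std::equal(ia, end, ib, end);
}
```

For example, FilesIdentical("data/new_t10k-images.idx3-ubyte", "data/database/MNIST/t10k-images.idx3-ubyte") should return true if the round trip preserved every byte.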

GitHub: https://github.com/fengbingchun/NN_Test 
