libtorch---day04[MNIST数据集]

参考pytorch

数据集读取

MNIST数据集是一个广泛使用的手写数字识别数据集,包含 60,000张训练图像和10,000张测试图像。每张图像是一个 $28\times 28$ 像素的灰度图像,标签是一个 0到9之间的数字,表示图像中的手写数字。

MNIST 数据集的文件格式

  • 图像文件格式
    图像文件包含一系列28x28像素的灰度图像,文件的前16个字节是头信息,后面是图像数据。头信息的格式如下:

魔数(4 字节):用于标识文件类型。对于图像文件,魔数是 2051(0x00000803)。
图像数量(4 字节):表示文件中包含的图像数量。
图像高度(4 字节):表示每个图像的高度,通常为 28。
图像宽度(4 字节):表示每个图像的宽度,通常为 28。
图像数据紧随其后,每个像素占用 1 字节,按行优先顺序存储。

  • 标签文件格式
    标签文件包含一系列标签,文件的前8个字节是头信息,后面是标签数据。头信息的格式如下:

魔数(4 字节):用于标识文件类型。对于标签文件,魔数是 2049(0x00000801)。
标签数量(4 字节):表示文件中包含的标签数量。
标签数据紧随其后,每个标签占用 1 字节,表示图像的类别(0 到 9)。

std::string dataset_path = "../dataSet/MNIST";

// Portable 32-bit byte swap (replaces the MSVC-only _byteswap_ulong).
// The idx file header fields are big-endian; little-endian hosts must swap.
static inline uint32_t swapBE32(uint32_t v) {
    return ((v & 0x000000FFu) << 24) | ((v & 0x0000FF00u) << 8) |
           ((v & 0x00FF0000u) >> 8)  | ((v & 0xFF000000u) >> 24);
}

// Read one big-endian int32 header field from the stream.
static inline int32_t readBE32(std::ifstream& fs) {
    uint32_t v = 0;
    fs.read(reinterpret_cast<char*>(&v), sizeof(v));
    return static_cast<int32_t>(swapBE32(v));
}

// Load MNIST images and labels from the idx files under dataset_path.
// images        : out, one HEIGHTxWIDTH CV_8UC1 Mat per image (deep copies).
// images_tensor : out, one flattened float32 tensor scaled to [0,1] per image.
// labels        : out, raw uint8 class labels (0-9).
// train_labels  : out, the labels as a 1-D uint8 tensor (owns its storage).
// train_data    : true -> train-* files, false -> t10k-* files.
void loadMNIST(std::vector<cv::Mat>& images, std::vector<torch::Tensor>& images_tensor, std::vector<uint8_t>& labels, torch::Tensor& train_labels, bool train_data) {
    std::string image_path = train_data ? dataset_path + "/train-images.idx3-ubyte" : dataset_path + "/t10k-images.idx3-ubyte";
    std::string label_path = train_data ? dataset_path + "/train-labels.idx1-ubyte" : dataset_path + "/t10k-labels.idx1-ubyte";

    images.clear();
    labels.clear();
    images_tensor.clear();

    std::ifstream fs(image_path, std::ios::binary);
    if (!fs.is_open()) {
        printf("can not open file %s\n", image_path.c_str());
        return;
    }
    int32_t magic_number = readBE32(fs);
    int32_t num = readBE32(fs);
    int32_t HEIGHT = readBE32(fs);
    int32_t WIDTH = readBE32(fs);
    printf("magic number: %d, image number: %d, image height: %d, image width: %d\n", magic_number, num, HEIGHT, WIDTH);
    if (magic_number != 2051) {  // 0x00000803, idx3 image-file magic
        printf("unexpected magic number in %s\n", image_path.c_str());
        return;
    }
    // Scratch buffer reused for every image; both output containers clone it.
    std::vector<unsigned char> image_data(static_cast<size_t>(HEIGHT) * WIDTH);
    for (int i = 0; i < num; i++) {
        if (!fs.read(reinterpret_cast<char*>(image_data.data()), image_data.size())) {
            printf("truncated image file %s\n", image_path.c_str());
            break;
        }
        // clone() so neither the Mat nor the Tensor aliases the scratch buffer.
        images.push_back(cv::Mat(HEIGHT, WIDTH, CV_8UC1, image_data.data()).clone());
        torch::Tensor image_torch = torch::from_blob(image_data.data(), { static_cast<long long>(image_data.size()) }, torch::kUInt8).clone();
        images_tensor.push_back(image_torch.to(torch::kF32) / 255.);
    }
    printf("image vector size: %d\n", int(images.size()));
    fs.close();

    fs.open(label_path, std::ios::binary);
    if (!fs.is_open()) {
        printf("can not open file %s\n", label_path.c_str());
        return;
    }
    magic_number = readBE32(fs);
    num = readBE32(fs);
    printf("magic number: %d, label number: %d\n", magic_number, num);
    if (magic_number != 2049) {  // 0x00000801, idx1 label-file magic
        printf("unexpected magic number in %s\n", label_path.c_str());
        return;
    }
    labels.resize(num);
    fs.read(reinterpret_cast<char*>(labels.data()), num);
    // clone() so the tensor owns its storage, independent of `labels`.
    train_labels = torch::from_blob(labels.data(), { num }, torch::kUInt8).clone();
    fs.close();
}
  • 要点
    • 在读取 MNIST 数据集时,文件中的数据是以大端序存储的,而大多数现代计算机(如 x86 架构)使用小端序。因此,在读取文件中的数据时,需要进行字节序转换,以确保数据的正确性。
    • 读取数据时应使用 for 循环按头信息中的数量读取;如果使用 while(!fs.eof()) 且不做额外处理,会多读出一条无效数据——因为 eofbit 只有在读取操作失败之后才会被置位。
    • 用 void* 数据(torch::from_blob)构造 Tensor 的时候,要注意调用 clone 方法,否则 Tensor 不拥有底层内存,原缓冲区释放后会产生悬空引用。

全部代码

这里就一个线性层,结合交叉熵函数;

#include <torch/torch.h>
#include <fstream>
#include <opencv2/opencv.hpp>

std::string dataset_path = "../dataSet/MNIST";

// Portable 32-bit byte swap (replaces the MSVC-only _byteswap_ulong).
// The idx file header fields are big-endian; little-endian hosts must swap.
static inline uint32_t swapBE32(uint32_t v) {
    return ((v & 0x000000FFu) << 24) | ((v & 0x0000FF00u) << 8) |
           ((v & 0x00FF0000u) >> 8)  | ((v & 0xFF000000u) >> 24);
}

// Read one big-endian int32 header field from the stream.
static inline int32_t readBE32(std::ifstream& fs) {
    uint32_t v = 0;
    fs.read(reinterpret_cast<char*>(&v), sizeof(v));
    return static_cast<int32_t>(swapBE32(v));
}

// Load MNIST images and labels from the idx files under dataset_path.
// images        : out, one HEIGHTxWIDTH CV_8UC1 Mat per image (deep copies).
// images_tensor : out, one flattened float32 tensor scaled to [0,1] per image.
// labels        : out, raw uint8 class labels (0-9).
// train_labels  : out, the labels as a 1-D uint8 tensor (owns its storage).
// train_data    : true -> train-* files, false -> t10k-* files.
void loadMNIST(std::vector<cv::Mat>& images, std::vector<torch::Tensor>& images_tensor, std::vector<uint8_t>& labels, torch::Tensor& train_labels, bool train_data) {
    std::string image_path = train_data ? dataset_path + "/train-images.idx3-ubyte" : dataset_path + "/t10k-images.idx3-ubyte";
    std::string label_path = train_data ? dataset_path + "/train-labels.idx1-ubyte" : dataset_path + "/t10k-labels.idx1-ubyte";

    images.clear();
    labels.clear();
    images_tensor.clear();

    std::ifstream fs(image_path, std::ios::binary);
    if (!fs.is_open()) {
        printf("can not open file %s\n", image_path.c_str());
        return;
    }
    int32_t magic_number = readBE32(fs);
    int32_t num = readBE32(fs);
    int32_t HEIGHT = readBE32(fs);
    int32_t WIDTH = readBE32(fs);
    printf("magic number: %d, image number: %d, image height: %d, image width: %d\n", magic_number, num, HEIGHT, WIDTH);
    if (magic_number != 2051) {  // 0x00000803, idx3 image-file magic
        printf("unexpected magic number in %s\n", image_path.c_str());
        return;
    }
    // Scratch buffer reused for every image; both output containers clone it.
    std::vector<unsigned char> image_data(static_cast<size_t>(HEIGHT) * WIDTH);
    for (int i = 0; i < num; i++) {
        if (!fs.read(reinterpret_cast<char*>(image_data.data()), image_data.size())) {
            printf("truncated image file %s\n", image_path.c_str());
            break;
        }
        // clone() so neither the Mat nor the Tensor aliases the scratch buffer.
        images.push_back(cv::Mat(HEIGHT, WIDTH, CV_8UC1, image_data.data()).clone());
        torch::Tensor image_torch = torch::from_blob(image_data.data(), { static_cast<long long>(image_data.size()) }, torch::kUInt8).clone();
        images_tensor.push_back(image_torch.to(torch::kF32) / 255.);
    }
    printf("image vector size: %d\n", int(images.size()));
    fs.close();

    fs.open(label_path, std::ios::binary);
    if (!fs.is_open()) {
        printf("can not open file %s\n", label_path.c_str());
        return;
    }
    magic_number = readBE32(fs);
    num = readBE32(fs);
    printf("magic number: %d, label number: %d\n", magic_number, num);
    if (magic_number != 2049) {  // 0x00000801, idx1 label-file magic
        printf("unexpected magic number in %s\n", label_path.c_str());
        return;
    }
    labels.resize(num);
    fs.read(reinterpret_cast<char*>(labels.data()), num);
    // clone() so the tensor owns its storage, independent of `labels`.
    train_labels = torch::from_blob(labels.data(), { num }, torch::kUInt8).clone();
    fs.close();
}
using namespace torch;

int main()
{
	std::vector<cv::Mat> images_show;
	std::vector<uint8_t> labels_train;
	std::vector<torch::Tensor> image_train;
	torch::Tensor label_train;
	 
	loadMNIST(images_show, image_train, labels_train, label_train, 0);
	torch::Tensor train_data = torch::stack(image_train);
	torch::Tensor train_data_label = label_train;
	torch::Tensor weights = torch::randn({ image_train[0].sizes()[0], 10 }).set_requires_grad(true);
	torch::Tensor bias = torch::randn({ 10 }).set_requires_grad(true);
	double lr = 1e-1;
	int iteration = 10000;
	torch::nn::CrossEntropyLoss criterion;
    torch::optim::SGD optim({weights, bias}, lr);
	for (int i = 0; i < iteration; i++)
	{
		auto predict = torch::matmul(train_data, weights) + bias;
		// std::cout << "predict data size: " << predict.sizes() << ", train_data_label data size: " << train_data_label.sizes() << std::endl;
		auto loss = criterion(predict, train_data_label);

		loss.backward();
        optim.step();
        optim.zero_grad();
        if((i+1) % 500 == 0)
		    printf("[%d /%d, loss: %lf]\n", i + 1, iteration, loss.item<double>());
	}

	loadMNIST(images_show, image_train, labels_train, label_train, 1);

	cv::Mat im_show;
	std::vector<cv::Mat> im_shows;
	for (int i = 0; i < 2; i++)
	{
		cv::Mat im_show_;
		std::vector<cv::Mat> im_shows_;
		for (int j = 0; j < 5; j++)
		{
            int index = torch::randint(0, images_show.size() - 1, {}).item<int>();
			cv::Mat im = images_show[index];
			torch::Tensor im_torch = image_train[index];
			uchar label = labels_train[index];
			auto predict = torch::matmul(im_torch, weights) + bias;
			auto label_predict = torch::argmax(predict.view({ -1 })).item<int>();
			cv::Mat im_resized;
			cv::resize(im, im_resized, cv::Size(16 * im.rows, 16 * im.cols));
			cv::cvtColor(im_resized, im, cv::COLOR_GRAY2RGB);
			cv::putText(im, "groud true: " + std::to_string(static_cast<int>(label)), cv::Point2f(40, 40), cv::FONT_HERSHEY_PLAIN, 3, cv::Scalar(0, 0, 255), 2);
			cv::putText(im, "predict true: " + std::to_string(label_predict), cv::Point2f(40, 90), cv::FONT_HERSHEY_PLAIN, 3, cv::Scalar(0, 255, 0), 2);
			im_shows_.push_back(im);
		}
		cv::hconcat(im_shows_, im_show_);
		im_shows.push_back(im_show_);
	}
	cv::vconcat(im_shows, im_show);
	try {
		/*cv::imshow("result", im_show);
		cv::waitKey(0);*/
        cv::imwrite("validation.png", im_show);
	}
	catch (cv::Exception& e)
	{
		printf("%s\n", e.what());
	}
	return 0;
}

结果

请添加图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值