关闭

解析 MNIST 数据库

标签: 深度学习
2575人阅读 评论(0) 收藏 举报
分类:

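Python：将 MNIST 训练集解析为图片

用法：`python 脚本名 <MNIST 解压目录> <输出目录>`。脚本读取 train-images.idx3-ubyte 与 train-labels.idx1-ubyte，每个数字最多保存 51 张 BMP 图片，文件名为“标签_序号.bmp”（序号从 0 开始）。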

import struct
import sys

import numpy as np
try:
    from PIL import Image  # Pillow / modern PIL
except ImportError:
    import Image  # legacy standalone PIL (Python 2 era)

# Usage: python <script> <mnist_dir> <output_dir>
# Reads the MNIST training set and writes at most MAX_ID + 1 images per label
# to <output_dir> as "<label>_<id>.bmp", where id counts from 0 per label.
input_path = sys.argv[1]   # directory containing the extracted MNIST files
output_path = sys.argv[2]  # directory the generated images are written to

MAX_ID = 50  # highest per-label id that is saved (ids 0..MAX_ID)

# =====read labels=====
label_file = input_path + '/train-labels.idx1-ubyte'
with open(label_file, 'rb') as label_fp:
    label_buf = label_fp.read()

label_index = 0
label_magic, label_numImages = struct.unpack_from('>II', label_buf, label_index)
if label_magic != 2049:
    raise ValueError('Bad magic number in ' + label_file)
label_index += struct.calcsize('>II')
# Use the count from the header instead of assuming exactly 60000 labels.
labels = struct.unpack_from('>%dB' % label_numImages, label_buf, label_index)

# =====read train images=====
train_file = input_path + '/train-images.idx3-ubyte'
with open(train_file, 'rb') as train_fp:
    buf = train_fp.read()

index = 0
magic, numImages, numRows, numColumns = struct.unpack_from('>IIII', buf, index)
if magic != 2051:
    raise ValueError('Bad magic number in ' + train_file)
if numImages != label_numImages:
    raise ValueError('Image count %d does not match label count %d'
                     % (numImages, label_numImages))
index += struct.calcsize('>IIII')
# One unsigned byte per pixel; size taken from the header, not hard-coded 28x28.
image_fmt = '>%dB' % (numRows * numColumns)
image_size = struct.calcsize(image_fmt)

label_map = {}  # label -> id assigned to the most recent image with that label
for image in range(numImages):
    label = labels[image]
    label_map[label] = label_map.get(label, -1) + 1
    ids = label_map[label]

    # Advance past this image's pixels *before* deciding whether to skip it;
    # skipping without advancing would pair later labels with the wrong pixels.
    image_offset = index
    index += image_size
    if ids > MAX_ID:
        continue

    im = struct.unpack_from(image_fmt, buf, image_offset)
    im = np.array(im, dtype='uint8').reshape(numRows, numColumns)
    Image.fromarray(im).save(output_path + '/%s_%s.bmp' % (label, ids), 'bmp')

MATLAB 解析：读取 MNIST 数据后，用基于 L1 距离的 K 近邻（K = 100）对测试集分类并计算错误率

引自:http://blog.csdn.net/wangyuquanliuli/article/details/17378317

主程序

% k-nearest-neighbour classification of the MNIST test set using the L1
% (sum of absolute differences) distance to every training image.
% Uses loadMNISTImages / loadMNISTLabels (defined below) and expects the four
% MNIST files in the current folder.
trainImages = loadMNISTImages('train-images.idx3-ubyte');   % pixels x numTrain
trainLabels = loadMNISTLabels('train-labels.idx1-ubyte');   % numTrain x 1
K = 100; % number of neighbours; can be any other value
testImages = loadMNISTImages('t10k-images.idx3-ubyte');
testLabels = loadMNISTLabels('t10k-labels.idx1-ubyte');
% Images are stored one per column, so count columns. length() would return
% the largest dimension (the pixel count) whenever there are fewer images
% than pixels.
trainLength = size(trainImages, 2);
testLength = size(testImages, 2);
K = min(K, trainLength); % cannot vote with more neighbours than exist
testResults = zeros(1, testLength);
tic;
for i = 1:testLength
    % L1 distance from test image i to every training image (1 x trainLength);
    % bsxfun avoids building an explicit repmat copy of the test image.
    dist = sum(abs(bsxfun(@minus, trainImages, testImages(:, i))), 1);
    [~, ind] = sort(dist);
    compLabel = trainLabels(ind(1:K));
    % Majority vote. On a tie mode() returns the smallest label, the same
    % choice as taking the first maximum of tabulate's ascending table.
    testResults(i) = mode(compLabel);

    disp(testResults(i));
    disp(testLabels(i));
end
% Compute the error on the test set (numErrors avoids shadowing error()).
numErrors = sum(testResults(:) ~= testLabels(:));

% Print out the classification error on the test set
errorRate = numErrors / testLength
% toc alone gives the elapsed time; "toc - tic" subtracted a new timer ID.
elapsed = toc;
fprintf('Elapsed time: %.2f s\n', elapsed);

两个子程序：loadMNISTImages 读取图像文件，loadMNISTLabels 读取标签文件

function images = loadMNISTImages(filename)
%loadMNISTImages Read an MNIST idx3-ubyte image file.
%   IMAGES = loadMNISTImages(FILENAME) returns a (numRows*numCols) x numImages
%   double matrix scaled to [0,1]. Column k is image k flattened in MATLAB
%   column-major order, so reshape(IMAGES(:,k), numRows, numCols) recovers it.
%   Errors if the file cannot be opened, has a magic number other than 2051,
%   or holds a different number of pixel bytes than its header declares.
fp = fopen(filename, 'rb');
assert(fp ~= -1, ['Could not open ', filename, '']);
closer = onCleanup(@() fclose(fp)); % closes the file on every exit path, incl. failed asserts
magic = fread(fp, 1, 'int32', 0, 'ieee-be');
assert(magic == 2051, ['Bad magic number in ', filename, '']);
numImages = fread(fp, 1, 'int32', 0, 'ieee-be');
numRows = fread(fp, 1, 'int32', 0, 'ieee-be');
numCols = fread(fp, 1, 'int32', 0, 'ieee-be');
% '*uint8' keeps one byte per pixel until the final conversion to double.
images = fread(fp, inf, '*uint8');
assert(numel(images) == numRows * numCols * numImages, ...
    ['Unexpected pixel count in ', filename, '']);
% Pixels are stored row by row; MATLAB fills column-major, so read as
% cols x rows x N and swap the first two dimensions.
images = reshape(images, numCols, numRows, numImages);
images = permute(images, [2 1 3]);
% Reshape to #pixels x #examples
images = reshape(images, numRows * numCols, numImages);
% Convert to double and rescale to [0,1]
images = double(images) / 255;
end
function labels = loadMNISTLabels(filename)
%loadMNISTLabels Read an MNIST idx1-ubyte label file.
%   LABELS = loadMNISTLabels(FILENAME) returns a numLabels x 1 double column
%   vector holding one label byte per image.
%   Errors if the file cannot be opened, has a magic number other than 2049,
%   or the number of label bytes differs from the count in the header.
fp = fopen(filename, 'rb');
assert(fp ~= -1, ['Could not open ', filename, '']);
closer = onCleanup(@() fclose(fp)); % close the file even when an assert below fails
magic = fread(fp, 1, 'int32', 0, 'ieee-be');
assert(magic == 2049, ['Bad magic number in ', filename, '']);
numLabels = fread(fp, 1, 'int32', 0, 'ieee-be');
labels = fread(fp, inf, 'unsigned char');
assert(size(labels,1) == numLabels, 'Mismatch in label count');
end

C++ 解析（依赖 OpenCV）：将测试集与训练集图像分别保存为以“标签_序号”命名的图片

引自:http://blog.csdn.net/fengbingchun/article/details/49611549

#include <cstdint>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include "opencv2/core/core.hpp"
#include "opencv2/highgui/highgui.hpp"
#include "opencv2/imgproc/imgproc.hpp"

using namespace std;

// Reverses the byte order of a 32-bit value (big-endian <-> little-endian).
//
// MNIST headers store 32-bit integers big-endian; after reading one into an
// int on a little-endian host, this restores the real value. The swap is done
// on an unsigned copy: shifting a byte with its top bit set into bit 31 of a
// signed int (as the original `(int)ch1 << 24` did) is undefined behaviour
// before C++20. Converting the result back to int wraps modulo 2^32 on all
// mainstream compilers.
int ReverseInt(int i)
{
    const std::uint32_t u = static_cast<std::uint32_t>(i);
    const std::uint32_t swapped = ((u & 0x000000FFu) << 24) |
                                  ((u & 0x0000FF00u) << 8) |
                                  ((u & 0x00FF0000u) >> 8) |
                                  ((u & 0xFF000000u) >> 24);
    return static_cast<int>(swapped);
}

void read_Mnist(string filename, vector<cv::Mat> &vec)
{
    ifstream file (filename, ios::binary);
    if (file.is_open()) {
        int magic_number = 0;
        int number_of_images = 0;
        int n_rows = 0;
        int n_cols = 0;
        file.read((char*) &magic_number, sizeof(magic_number));
        magic_number = ReverseInt(magic_number);
        file.read((char*) &number_of_images,sizeof(number_of_images));
        number_of_images = ReverseInt(number_of_images);
        file.read((char*) &n_rows, sizeof(n_rows));
        n_rows = ReverseInt(n_rows);
        file.read((char*) &n_cols, sizeof(n_cols));
        n_cols = ReverseInt(n_cols);

        for(int i = 0; i < number_of_images; ++i) {
            cv::Mat tp = cv::Mat::zeros(n_rows, n_cols, CV_8UC1);
            for(int r = 0; r < n_rows; ++r) {
                for(int c = 0; c < n_cols; ++c) {
                    unsigned char temp = 0;
                    file.read((char*) &temp, sizeof(temp));
                    tp.at<uchar>(r, c) = (int) temp;
                }
            }
            vec.push_back(tp);
        }
    }
}

// Reads an MNIST idx1-ubyte label file into `vec`.
//
// File layout: two big-endian 32-bit integers (magic 2049, label count),
// then one unsigned byte per label.
// - Success: `vec` is resized to the header's label count and element i
//   holds label i. The caller no longer has to pre-size `vec`; the previous
//   `vec[i] = ...` wrote out of bounds when it was too small.
// - File cannot be opened, or header is bad: an error goes to stderr and
//   `vec` is left unchanged.
// - File is truncated: `vec` holds only the labels actually read.
void read_Mnist_Label(std::string filename, std::vector<int> &vec)
{
    std::ifstream file(filename, std::ios::binary);
    if (!file.is_open()) {
        std::cerr << "cannot open MNIST label file " << filename << '\n';
        return;
    }

    // Decode big-endian byte by byte so the result does not depend on the
    // host's endianness.
    auto read_be_u32 = [&file](std::uint32_t &value) {
        unsigned char b[4];
        if (!file.read(reinterpret_cast<char *>(b), sizeof b))
            return false;
        value = (std::uint32_t{b[0]} << 24) | (std::uint32_t{b[1]} << 16) |
                (std::uint32_t{b[2]} << 8) | std::uint32_t{b[3]};
        return true;
    };

    std::uint32_t magic_number = 0;
    std::uint32_t number_of_labels = 0;
    if (!read_be_u32(magic_number) || magic_number != 2049 ||
        !read_be_u32(number_of_labels)) {
        std::cerr << "bad MNIST label header in " << filename << '\n';
        return;
    }

    std::vector<unsigned char> raw(number_of_labels);
    file.read(reinterpret_cast<char *>(raw.data()),
              static_cast<std::streamsize>(raw.size()));
    raw.resize(static_cast<std::size_t>(file.gcount()));
    if (raw.size() != number_of_labels)
        std::cerr << "MNIST label file " << filename << " is truncated\n";

    vec.assign(raw.begin(), raw.end());
}

// Builds the output file name "<number>_<count>" for one image, e.g. "3_00001".
//
// `arr` holds one running counter per digit (10 entries).
// - When `number` is 0..9: arr[number] is incremented first, and the new count
//   is appended left-padded with zeros to at least 5 digits (counters are
//   expected to be non-negative).
// - For any other `number`: no counter changes and the result is "<number>_".
//
// std::to_string replaces sprintf into fixed 10-byte buffers, which could
// overflow for large-magnitude values such as INT_MIN.
std::string GetImageName(int number, int arr[])
{
    std::string count;
    if (number >= 0 && number < 10) {
        count = std::to_string(++arr[number]);
        if (count.size() < 5)
            count.insert(0, 5 - count.size(), '0');
    }
    return std::to_string(number) + "_" + count;
}

int main()
{
    //reference: http://eric-yuan.me/cpp-read-mnist/
    //test images and test labels
    //read MNIST image into OpenCV Mat vector
    string filename_test_images = "D:/Download/t10k-images-idx3-ubyte/t10k-images.idx3-ubyte";
    int number_of_test_images = 10000;
    vector<cv::Mat> vec_test_images;

    read_Mnist(filename_test_images, vec_test_images);

    //read MNIST label into int vector
    string filename_test_labels = "D:/Download/t10k-labels-idx1-ubyte/t10k-labels.idx1-ubyte";
    vector<int> vec_test_labels(number_of_test_images);

    read_Mnist_Label(filename_test_labels, vec_test_labels);

    if (vec_test_images.size() != vec_test_labels.size()) {
        cout<<"parse MNIST test file error"<<endl;
        return -1;
    }

    //save test images
    int count_digits[10];
    for (int i = 0; i < 10; i++)
        count_digits[i] = 0;

    string save_test_images_path = "D:/Download/MNIST/test_images/";

    for (int i = 0; i < vec_test_images.size(); i++) {
        int number = vec_test_labels[i];
        string image_name = GetImageName(number, count_digits);
        image_name = save_test_images_path + image_name + ".jpg";

        cv::imwrite(image_name, vec_test_images[i]);
    }

    //train images and train labels
    //read MNIST image into OpenCV Mat vector
    string filename_train_images = "D:/Download/train-images-idx3-ubyte/train-images.idx3-ubyte";
    int number_of_train_images = 60000;
    vector<cv::Mat> vec_train_images;

    read_Mnist(filename_train_images, vec_train_images);

    //read MNIST label into int vector
    string filename_train_labels = "D:/Download/train-labels-idx1-ubyte/train-labels.idx1-ubyte";
    vector<int> vec_train_labels(number_of_train_images);

    read_Mnist_Label(filename_train_labels, vec_train_labels);

    if (vec_train_images.size() != vec_train_labels.size()) {
        cout<<"parse MNIST train file error"<<endl;
        return -1;
    }

    //save train images
    for (int i = 0; i < 10; i++)
        count_digits[i] = 0;

    string save_train_images_path = "D:/Download/MNIST/train_images/";

    for (int i = 0; i < vec_train_images.size(); i++) {
        int number = vec_train_labels[i];
        string image_name = GetImageName(number, count_digits);
        image_name = save_train_images_path + image_name + ".jpg";

        cv::imwrite(image_name, vec_train_images[i]);
    }

    return 0;
}
1
0

查看评论
* 以上用户言论只代表其个人观点,不代表CSDN网站的观点或立场
    个人资料
    • 访问:289039次
    • 积分:3909
    • 等级:
    • 排名:第8059名
    • 原创:131篇
    • 转载:59篇
    • 译文:23篇
    • 评论:28条
    最新评论