MNIST数据集训练

MNIST数据集训练

下载数据集
cd data/mnist/
./get_mnist.sh

get_mnist.sh 脚本用于下载 MNIST 数据集并将其解压

原始数据集包括四个文件

  • train-images-idx3-ubyte 训练集,图片
  • train-labels-idx1-ubyte 训练集,标签
  • t10k-images-idx3-ubyte 测试集,图片
  • t10k-labels-idx1-ubyte 测试集,标签

数据集转换为图像
#include <gflags/gflags.h>
#include <glog/logging.h>

#include <stdint.h>
#include <sys/stat.h>

#include <fstream>  // NOLINT(readability/streams)
#include <string>
#include <vector>

#include "opencv2/core/core.hpp"
#include "opencv2/highgui/highgui.hpp"
#include "opencv2/imgproc/imgproc.hpp"

using std::string;

// Command-line flags read by convert_image().
// Number of digit images stacked vertically in the output picture.
DEFINE_int32(rows, 25, "The rows of index in image");
// Number of digit images placed side by side in the output picture.
DEFINE_int32(cols, 40, "The cols of index in image");
// Number of images to skip after the file header before the first tile is read.
DEFINE_int32(offset, 0, "The offset of index in raw image");

// Reverses the byte order of a 32-bit value, i.e. converts between the
// big-endian and little-endian representations of the same integer.
uint32_t swap_endian(uint32_t val) {
  const uint32_t lowest_to_top = (val & 0x000000FFu) << 24;
  const uint32_t low_to_high = (val & 0x0000FF00u) << 8;
  const uint32_t high_to_low = (val & 0x00FF0000u) >> 8;
  const uint32_t top_to_lowest = (val & 0xFF000000u) >> 24;
  return lowest_to_top | low_to_high | high_to_low | top_to_lowest;
}

//数据集转换函数,输入参数:MNIST数据集文件,图片文件
void convert_image(const char* image_filename, const char* png_filename) {
  // Open files
  std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
  CHECK(image_file) << "Unable to open file " << image_filename;
  // Read the magic and the meta data
  uint32_t magic;
  uint32_t num_items;
  uint32_t rows;
  uint32_t cols;

  //读取魔数
  image_file.read(reinterpret_cast<char*>(&magic), 4);
  magic = swap_endian(magic);
  CHECK_EQ(magic, 2051) << "Incorrect image file magic.";
  //读取数据条目总数
  image_file.read(reinterpret_cast<char*>(&num_items), 4);
  num_items = swap_endian(num_items);
  //读取行数
  image_file.read(reinterpret_cast<char*>(&rows), 4);
  rows = swap_endian(rows);
  //读取列数
  image_file.read(reinterpret_cast<char*>(&cols), 4);
  cols = swap_endian(cols);

  //命令行参数读取
  const int flag_rows = FLAGS_rows;
  const int flag_cols = FLAGS_cols;
  const int offset = FLAGS_offset;
  const int width = flag_cols*cols;
  const int height = flag_rows*rows;

  char* pixels = new char[rows * cols];
  cv::Mat tp = cv::Mat::zeros(height, width, CV_8UC1);
  //使用读取MINST数据,写入到opencv中的Mat类对象中
  image_file.seekg(offset*rows*cols, std::ios::cur);
  for(int i=0; i<flag_rows; i++) {
    for(int j=0; j<flag_cols; j++) {
      if(!image_file.eof()) {
        image_file.read(pixels, rows * cols);
        for(int k=0; k<rows; k++) {
          for(int l=0; l<cols; l++) {
            tp.at<uchar>(k + i*rows, j*cols + l) = (int)pixels[k*cols+l]; 
          }
        }
      }
      else {
        for(int k=0; k<rows; k++) {
          for(int l=0; l<cols; l++) {
            tp.at<uchar>(k + i*rows, j*cols + l) = 0; 
          }
        }
      }
    }
  }
  //调用opencv中的函数保存图片
  cv::imwrite(png_filename, tp); 
}

// Program entry point: parses the command-line flags and converts the given
// MNIST image file into a picture. Expects exactly two positional arguments
// (input image file, output picture file); otherwise prints the usage text.
int main(int argc, char** argv) {
#ifndef GFLAGS_GFLAGS_H_
  namespace gflags = google;
#endif

  // Help text shown for wrong usage.
  static const char kUsage[] =
        "This script converts the MNIST dataset to\n"
        "image(png) format.\n"
        "Usage:\n"
        "    convert_mnist_data [FLAGS] input_image_file "
        "output_png_file\n"
        "The MNIST dataset could be downloaded at\n"
        "    http://yann.lecun.com/exdb/mnist/\n"
        "You should gunzip them after downloading,"
        "or directly use data/mnist/get_mnist.sh\n";

  FLAGS_alsologtostderr = 1;
  gflags::SetUsageMessage(kUsage);
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  // Anything other than program name + two positional arguments: show the
  // usage text and finish without converting.
  if (argc != 3) {
    gflags::ShowUsageWithFlagsRestrict(argv[0],
        "examples/mnist/convert_mnist_data");
    return 0;
  }

  google::InitGoogleLogging(argv[0]);
  // Convert the image file.
  convert_image(argv[1], argv[2]);
  return 0;
}
转换格式

下载到的原始数据集为二进制文件,需要转换为LEVELDB或LMDB才能被caffe识别

所以需要运行脚本

./examples/mnist/create_mnist.sh

此时在examples/mnist里生成了mnist_train_lmdb和mnist_test_lmdb两个目录,每个目录下都有data.mdb和lock.mdb

训练超参数
examples/mnist/train_lenet.sh

使用CPU模式运行

打印训练超参数文件examples/mnist/lenet_solver.prototxt

脚本中有指定CNN网络描述文件

解析CNN网络描述文件中的网络参数,创建训练网络

训练mnist

产生两个输出,data为图片数据,label为标签数据

打开训练lmdb,累计增加

创建中间层

最后一层loss

创建测试网络

添加accuracy

迭代次数增加,loss下降

获得最终loss值和accuracy值

用训练好的模型对数据进行预测
./build/tools/caffe.bin test \
    -model examples/mnist/lenet_train_test.prototxt \
    -weights examples/mnist/lenet_iter_10000.caffemodel \
    -iterations 100
mnist样本字库的图片转换
import numpy as np
import struct 
import matplotlib.pyplot as plt
import Image

# Unpack every image stored in an MNIST image file and save each one as a BMP.
filename = 't10k-images-idx3-ubyte'

# Read the whole file up front; the context manager closes the handle.
with open(filename, 'rb') as binfile:
    buf = binfile.read()

index = 0
# Header: four big-endian unsigned 32-bit integers.
magic, numImages, numRows, numColumns = struct.unpack_from('>IIII', buf, index)
index += struct.calcsize('>IIII')

# One unsigned byte per pixel. The image size comes from the header instead of
# a hard-coded 28 x 28.
image_format = '>%dB' % (numRows * numColumns)
image_size = struct.calcsize(image_format)

for image in range(numImages):
    im = struct.unpack_from(image_format, buf, index)
    index += image_size

    im = np.array(im, dtype='uint8').reshape(numRows, numColumns)

    # NOTE(review): the output path says "mnist_train" although the input is
    # the t10k file, and the directory presumably has to exist already - confirm.
    im = Image.fromarray(im)
    im.save('data/mnist/mnist_train/train_%s.bmp' % image, 'bmp')

手写脚本测试
  • 必须是256级灰度(8位)图像
  • 必须是黑底白字
  • 像素大小必须是28 × 28
  • 数字在图片中间,上下左右没有太多的空白
import os
import sys
import numpy as np
import matplotlib.pyplot as plt

# Put the "python" directory under caffe_root on the module search path
# before importing caffe, so that the insertion can affect the import.
caffe_root = '/home/caffe'
sys.path.insert(0, os.path.join(caffe_root, 'python'))
import caffe

MODEL_FILE = '../mnist/lenet.prototxt'
PRETRAINED = '../mnist/lenet_iter_10000.caffemodel'
IMAGE_FILE = 'demo.bmp'

# Select CPU mode before the network is built and run.
caffe.set_mode_cpu()

# Load the test picture as a single-channel image and classify it.
input_image = caffe.io.load_image(IMAGE_FILE, color=False)
net = caffe.Classifier(MODEL_FILE, PRETRAINED)
prediction = net.predict([input_image], oversample=False)

# Index of the highest score; this form prints the same text on Python 2 and 3.
print('predicted class: %s' % prediction[0].argmax())

图片

测试结果

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值