test_lenet5_mnist

 

#define __WINDOWS__ 1
//#define MNIST_PATH "E:\\ebLearn\\mnist"
#include "libidx.h"
#include "libeblearn.h"
#include "libeblearntools.h"
//#include "netconf.h"
#include <iostream>
#include <stdio.h>

using namespace std;
using namespace ebl;
// Global counter; it is not referenced anywhere in this file, so it is
// presumably consumed by library code (NOTE(review): confirm before removing).
uint dump_count(0);

// MNIST directory, heap-allocated by main() only when MNIST_PATH is non-empty;
// otherwise it stays NULL.
string *gl_mnist_dir(NULL);
// Notice about the MNIST directory being unknown, heap-allocated by main().
string *gl_mnist_errmsg(NULL);

// Forward declaration so main() can instantiate the template that is
// defined after it.
template <typename Tnet>
void test_lenet5_mnist(string *dir, string *errmsg, double eta);
int main(int argc,char** argv)
{
	
	if (strlen(MNIST_PATH) > 0) gl_mnist_dir = new string(MNIST_PATH);
	gl_mnist_errmsg =
      new string("MNIST directory is unknown, some tests will be ignored (MNIST can be downloaded at http://yann.lecun.com/exdb/mnist/)");
	test_lenet5_mnist<float>(gl_mnist_dir, gl_mnist_errmsg, 1e-5);
	if (gl_mnist_dir) delete gl_mnist_dir;
	if (gl_mnist_errmsg) delete gl_mnist_errmsg;
	return 0;
}

////////////////////////////////////////////////////////////////
/// Trains a LeNet-5 network on a subset of MNIST for 5 epochs, testing on
/// both the training and test subsets before training and after each epoch.
/// @param dir     MNIST directory; when NULL the test is skipped.
/// @param errmsg  message printed to stderr when the test is skipped
///                (may itself be NULL).
/// @param eta     learning rate passed as the first gd_param argument.
/// @tparam Tnet   numeric type of the network (main() uses float).
template <typename Tnet>
void test_lenet5_mnist(string *dir, string *errmsg, double eta) 
{
  // The datasource constructors below dereference dir, so bail out early
  // instead of crashing when no MNIST directory is known.
  if (!dir) {
    if (errmsg) cerr << *errmsg << endl;
    return;
  }

  cout << "initializing random seed with " << fixed_init_drand() << endl;
  typedef ubyte Tdata; // sample storage type
  typedef ubyte Tlab;  // label storage type
#ifdef __GUI__
  bool display = true;
#endif
  //uint ninternals = 1;
  cout << endl;

  // load MNIST dataset (presumably a 1000-sample test subset and a
  // 2000-sample training subset -- confirm against mnist_datasource)
  mnist_datasource<Tnet,Tdata,Tlab>
    test_ds(dir->c_str(), false, 1000),
    train_ds(dir->c_str(), true, 2000);
  train_ds.set_balanced(true);
  train_ds.set_shuffle_passes(true);
  // set 2nd argument to true for focusing on hardest examples
  //  train_ds.set_weigh_samples(true, false, false, 0.01);
  train_ds.set_weigh_samples(false, false, false, 0.01);
  test_ds.set_epoch_show(500);  // show progress every 500 samples
  train_ds.set_epoch_show(500); // show progress every 500 samples
  train_ds.ignore_correct(true);

  // create 1-of-n targets with target 1.0 for shown class, -1.0 for the rest
  idx<Tnet> targets =
    create_target_matrix<Tnet>(1+idx_max(train_ds.labels), 1.0);
  // one target row per class (presumably 1 + largest training label)
  uint nclasses = targets.dim(0);

  // create the network weights, network and trainer
  ddparameter<Tnet> theparam(60000); // create trainable parameter
  lenet5<Tnet> net(theparam, 32, 32, 5, 5, 2, 2, 5, 5, 2, 2, 120,
                   nclasses, true, false, true, false);
  cout << net.describe() << endl;

  l2_energy<Tnet> energy;
  class_answer<Tnet,Tdata,Tlab> answer(nclasses);
  trainable_module<Tnet,Tdata,Tlab> trainable(energy, net, NULL, &answer);
  // same data/label typedefs as every other component above
  supervised_trainer<Tnet,Tdata,Tlab> thetrainer(trainable, theparam);

#ifdef __GUI__
  //  labeled_datasource_gui<Tnet, ubyte, ubyte> dsgui(true);
  //   dsgui.display(test_ds, 10, 10);
  supervised_trainer_gui<Tnet,Tdata,Tlab> stgui;
#endif

  // a classifier-meter measures classification errors
  classifier_meter trainmeter;
  classifier_meter testmeter;

  // initialize the network weights
  forget_param_linear fgp(1, 0.5, (int) 0 /* fixed seed */);
  trainable.forget(fgp);

  // gradient parameters
  gd_param gdp(/* double eta*/ eta,
               /* double ln */ 0.0,
               /* double l1 */ 0.0,
               /* double l2 */ 0.0,
               /* intg dtime */ 0,
               /* double iner */ 0.0,
               /* double anneal_value */ 0.001,
               /* intg anneal_period */ 2000,
               /* double g_t*/ 0.0);
  cout << gdp << endl;
  infer_param infp;

  // estimate second derivative on 100 iterations, using mu=0.02
  // and set individual epsilons
  thetrainer.compute_diaghessian(train_ds, 100, 0.02);

#ifdef __GUI__
  if (display) {
    stgui.display_datasource(thetrainer, test_ds, infp, 10, 10);
    // TODO: this causes compilation error now
    //    stgui.display_internals(thetrainer, test_ds, infp, gdp, ninternals);
  }
#endif

  // do training iterations
  cout << "training with " << train_ds.size() << " training samples and ";
  cout << test_ds.size() << " test samples" << endl;

  // baseline measurements before any training
  thetrainer.test(train_ds, trainmeter, infp);
  thetrainer.test(test_ds, testmeter, infp);
#ifdef __GUI__
  if (display) {
    stgui.display_datasource(thetrainer, test_ds, infp, 10, 10);
    // TODO: this causes compilation error now
    //    stgui.display_internals(thetrainer, test_ds, infp, gdp, ninternals);
  }
#endif
  // this goes at about 25 examples per second on a PIIIM 800MHz
  for (int i = 0; i < 5; ++i)
  {
    cout << "__ epoch " << i + 1 << " ______________"
         << "_________________________________________________________" << endl;
    // NOTE(review): label says "before hessian" but the per-epoch hessian
    // recomputation right below is commented out.
    DEBUGMEM_PRETTY("before hessian");
    //thetrainer.compute_diaghessian(train_ds, 100, 0.02);
    thetrainer.train(train_ds, trainmeter, gdp, 1, infp, 1000, 100, .02);
    // ostringstream name(""); name << "pickings_" << i+1;
    // train_ds.save_pickings(name.str().c_str());
    thetrainer.test(train_ds, trainmeter, infp);
    thetrainer.test(test_ds, testmeter, infp);

#ifdef __GUI__
    if (display) {
      //      dsgui.display_pickings(train_ds, 3, 3);
      // a single call suffices; repeating it only redraws the same view
      stgui.display_datasource(thetrainer, test_ds, infp, 10, 10);
      // TODO: this causes compilation error now
      //      stgui.display_internals(thetrainer, test_ds, infp, gdp, ninternals);
    }
#endif
  }
}


 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
好的,下面是基于Python实现LeNet5的MNIST手写字体识别的步骤: 1. 数据准备:下载MNIST数据集,包括训练集和测试集。可以使用Python库`tensorflow_datasets`来下载和加载MNIST数据集。 ```python import tensorflow as tf import tensorflow_datasets as tfds # 加载MNIST数据集 mnist_dataset = tfds.load(name="mnist", split=tfds.Split.TRAIN) mnist_test_dataset = tfds.load(name="mnist", split=tfds.Split.TEST) ``` 2. 数据预处理:对数据进行标准化处理,并将训练集和测试集转换为合适的格式。 ```python def preprocess_data(data): # 将数据转换为浮点数类型 data['image'] = tf.cast(data['image'], tf.float32) # 标准化处理 data['image'] = tf.divide(data['image'], 255.0) # 调整形状(逐样本处理,batch 维度由后面的 .batch() 添加) data['image'] = tf.reshape(data['image'], [28, 28, 1]) # 将标签转换为One-hot编码 data['label'] = tf.one_hot(data['label'], depth=10) # Keras 的 fit 需要 (输入, 标签) 元组,而不是字典 return data['image'], data['label'] # 对训练集和测试集进行预处理 mnist_dataset = mnist_dataset.map(preprocess_data) mnist_test_dataset = mnist_test_dataset.map(preprocess_data) # 将训练集转换为可迭代的数据集 batch_size = 32 mnist_dataset = mnist_dataset.batch(batch_size) # 将测试集转换为可迭代的数据集 mnist_test_dataset = mnist_test_dataset.batch(batch_size) ``` 3. 构建LeNet5模型:使用TensorFlow构建LeNet5模型,包括卷积层、池化层和全连接层。 ```python from tensorflow.keras import layers, models # 构建LeNet5模型 model = models.Sequential([ layers.Conv2D(filters=6, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)), layers.MaxPooling2D(pool_size=(2, 2)), layers.Conv2D(filters=16, kernel_size=(5, 5), activation='relu'), layers.MaxPooling2D(pool_size=(2, 2)), layers.Flatten(), layers.Dense(units=120, activation='relu'), layers.Dense(units=84, activation='relu'), layers.Dense(units=10, activation='softmax') ]) ``` 4. 编译模型:定义损失函数、优化器和评估指标。 ```python # 定义损失函数、优化器和评估指标 model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) ``` 5. 训练模型:使用训练集训练模型,并在测试集上进行评估。 ```python epochs = 10 history = model.fit(mnist_dataset, epochs=epochs, validation_data=mnist_test_dataset) ``` 6. 
可视化训练过程:使用Matplotlib库可视化训练过程。 ```python import matplotlib.pyplot as plt # 可视化训练过程 plt.plot(history.history['accuracy'], label='training accuracy') plt.plot(history.history['val_accuracy'], label='validation accuracy') plt.title('Training and Validation Accuracy') plt.xlabel('Epoch') plt.ylabel('Accuracy') plt.legend() plt.show() ``` 7. 预测结果:使用训练好的模型对新的手写数字进行预测。 ```python import numpy as np # 加载新的手写数字图片(假设已是 28x28 的 RGB/RGBA 图片) new_image = plt.imread('new_image.png') # 将图片转换为灰度图像 new_image = np.dot(new_image[...,:3], [0.299, 0.587, 0.114]) # 调整形状 new_image = np.reshape(new_image, (1, 28, 28, 1)) # 标准化处理:plt.imread 读取 PNG 时已返回 [0, 1] 范围的浮点数,无需再除以 255 # 对新的手写数字进行预测 prediction = model.predict(new_image) # 打印预测结果 print(np.argmax(prediction)) ``` 以上就是基于Python实现LeNet5的MNIST手写字体识别的步骤。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值