tiny-cnn开源库的使用(MNIST)

转自:http://blog.csdn.net/fengbingchun/article/details/50573841

tiny-cnn是一个基于CNN的开源库,它的License是BSD 3-Clause。作者也一直在维护更新,对进一步掌握CNN很有帮助,因此下面介绍下tiny-cnn在windows7 64bit vs2013的编译及使用。

1.      从https://github.com/nyanp/tiny-cnn下载源码:

Git clone https://github.com/nyanp/tiny-cnn.git  版本号为77d80a8,更新日期2016.01.22

2.      源文件中已经包含了vs2013工程,vc/tiny-cnn.sln,默认是win32的,examples/main.cpp需要OpenCV的支持,这里新建一个x64的控制台工程tiny-cnn;

3.      仿照源工程,将相应.h文件加入到新控制台工程中,新加一个test_tiny-cnn.cpp文件;

4.      将examples/mnist中test.cpp和train.cpp文件中的代码复制到test_tiny-cnn.cpp文件中;

[cpp]  view plain  copy
  1. #include <iostream>  
  2. #include <string>  
  3. #include <vector>  
  4. #include <algorithm>  
  5. #include <tiny_cnn/tiny_cnn.h>  
  6. #include <opencv2/opencv.hpp>  
  7.   
  8. using namespace tiny_cnn;  
  9. using namespace tiny_cnn::activation;  
  10.   
  11. // rescale output to 0-100  
  12. template <typename Activation>  
  13. double rescale(double x)  
  14. {  
  15.     Activation a;  
  16.     return 100.0 * (x - a.scale().first) / (a.scale().second - a.scale().first);  
  17. }  
  18.   
  19. void construct_net(network<mse, adagrad>& nn);  
  20. void train_lenet(std::string data_dir_path);  
  21. // convert tiny_cnn::image to cv::Mat and resize  
  22. cv::Mat image2mat(image<>& img);  
  23. void convert_image(const std::string& imagefilename, double minv, double maxv, int w, int h, vec_t& data);  
  24. void recognize(const std::string& dictionary, const std::string& filename, int target);  
  25.   
  26. int main()  
  27. {  
  28.     //train  
  29.     std::string data_path = "D:/Download/MNIST";  
  30.     train_lenet(data_path);  
  31.   
  32.     //test  
  33.     std::string model_path = "D:/Download/MNIST/LeNet-weights";  
  34.     std::string image_path = "D:/Download/MNIST/";  
  35.     int target[10] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };  
  36.   
  37.     for (int i = 0; i < 10; i++) {  
  38.         char ch[15];  
  39.         sprintf(ch, "%d", i);  
  40.         std::string str;  
  41.         str = std::string(ch);  
  42.         str += ".png";  
  43.         str = image_path + str;  
  44.   
  45.         recognize(model_path, str, target[i]);  
  46.     }  
  47.   
  48.     std::cout << "ok!" << std::endl;  
  49.     return 0;  
  50. }  
  51.   
  52. void train_lenet(std::string data_dir_path) {  
  53.     // specify loss-function and learning strategy  
  54.     network<mse, adagrad> nn;  
  55.   
  56.     construct_net(nn);  
  57.   
  58.     std::cout << "load models..." << std::endl;  
  59.   
  60.     // load MNIST dataset  
  61.     std::vector<label_t> train_labels, test_labels;  
  62.     std::vector<vec_t> train_images, test_images;  
  63.   
  64.     parse_mnist_labels(data_dir_path + "/train-labels.idx1-ubyte",  
  65.         &train_labels);  
  66.     parse_mnist_images(data_dir_path + "/train-images.idx3-ubyte",  
  67.         &train_images, -1.0, 1.0, 2, 2);  
  68.     parse_mnist_labels(data_dir_path + "/t10k-labels.idx1-ubyte",  
  69.         &test_labels);  
  70.     parse_mnist_images(data_dir_path + "/t10k-images.idx3-ubyte",  
  71.         &test_images, -1.0, 1.0, 2, 2);  
  72.   
  73.     std::cout << "start training" << std::endl;  
  74.   
  75.     progress_display disp(train_images.size());  
  76.     timer t;  
  77.     int minibatch_size = 10;  
  78.     int num_epochs = 30;  
  79.   
  80.     nn.optimizer().alpha *= std::sqrt(minibatch_size);  
  81.   
  82.     // create callback  
  83.     auto on_enumerate_epoch = [&](){  
  84.         std::cout << t.elapsed() << "s elapsed." << std::endl;  
  85.         tiny_cnn::result res = nn.test(test_images, test_labels);  
  86.         std::cout << res.num_success << "/" << res.num_total << std::endl;  
  87.   
  88.         disp.restart(train_images.size());  
  89.         t.restart();  
  90.     };  
  91.   
  92.     auto on_enumerate_minibatch = [&](){  
  93.         disp += minibatch_size;  
  94.     };  
  95.   
  96.     // training  
  97.     nn.train(train_images, train_labels, minibatch_size, num_epochs,  
  98.         on_enumerate_minibatch, on_enumerate_epoch);  
  99.   
  100.     std::cout << "end training." << std::endl;  
  101.   
  102.     // test and show results  
  103.     nn.test(test_images, test_labels).print_detail(std::cout);  
  104.   
  105.     // save networks  
  106.     std::ofstream ofs("D:/Download/MNIST/LeNet-weights");  
  107.     ofs << nn;  
  108. }  
  109.   
  110. void construct_net(network<mse, adagrad>& nn) {  
  111.     // connection table [Y.Lecun, 1998 Table.1]  
  112. #define O true  
  113. #define X false  
  114.     static const bool tbl[] = {  
  115.         O, X, X, X, O, O, O, X, X, O, O, O, O, X, O, O,  
  116.         O, O, X, X, X, O, O, O, X, X, O, O, O, O, X, O,  
  117.         O, O, O, X, X, X, O, O, O, X, X, O, X, O, O, O,  
  118.         X, O, O, O, X, X, O, O, O, O, X, X, O, X, O, O,  
  119.         X, X, O, O, O, X, X, O, O, O, O, X, O, O, X, O,  
  120.         X, X, X, O, O, O, X, X, O, O, O, O, X, O, O, O  
  121.     };  
  122. #undef O  
  123. #undef X  
  124.   
  125.     // construct nets  
  126.     nn << convolutional_layer<tan_h>(32, 32, 5, 1, 6)  // C1, 1@32x32-in, 6@28x28-out  
  127.         << average_pooling_layer<tan_h>(28, 28, 6, 2)   // S2, 6@28x28-in, 6@14x14-out  
  128.         << convolutional_layer<tan_h>(14, 14, 5, 6, 16,  
  129.         connection_table(tbl, 6, 16))              // C3, 6@14x14-in, 16@10x10-in  
  130.         << average_pooling_layer<tan_h>(10, 10, 16, 2)  // S4, 16@10x10-in, 16@5x5-out  
  131.         << convolutional_layer<tan_h>(5, 5, 5, 16, 120) // C5, 16@5x5-in, 120@1x1-out  
  132.         << fully_connected_layer<tan_h>(120, 10);       // F6, 120-in, 10-out  
  133. }  
  134.   
  135. void recognize(const std::string& dictionary, const std::string& filename, int target) {  
  136.     network<mse, adagrad> nn;  
  137.   
  138.     construct_net(nn);  
  139.   
  140.     // load nets  
  141.     std::ifstream ifs(dictionary.c_str());  
  142.     ifs >> nn;  
  143.   
  144.     // convert imagefile to vec_t  
  145.     vec_t data;  
  146.     convert_image(filename, -1.0, 1.0, 32, 32, data);  
  147.   
  148.     // recognize  
  149.     auto res = nn.predict(data);  
  150.     std::vector<std::pair<doubleint> > scores;  
  151.   
  152.     // sort & print top-3  
  153.     for (int i = 0; i < 10; i++)  
  154.         scores.emplace_back(rescale<tan_h>(res[i]), i);  
  155.   
  156.     std::sort(scores.begin(), scores.end(), std::greater<std::pair<doubleint>>());  
  157.   
  158.     for (int i = 0; i < 3; i++)  
  159.         std::cout << scores[i].second << "," << scores[i].first << std::endl;  
  160.   
  161.     std::cout << "the actual digit is: " << scores[0].second << ", correct digit is: "<<target<<std::endl;  
  162.   
  163.     // visualize outputs of each layer  
  164.     //for (size_t i = 0; i < nn.depth(); i++) {  
  165.     //  auto out_img = nn[i]->output_to_image();  
  166.     //  cv::imshow("layer:" + std::to_string(i), image2mat(out_img));  
  167.     //}  
  168.      visualize filter shape of first convolutional layer  
  169.     //auto weight = nn.at<convolutional_layer<tan_h>>(0).weight_to_image();  
  170.     //cv::imshow("weights:", image2mat(weight));  
  171.   
  172.     //cv::waitKey(0);  
  173. }  
  174.   
  175. // convert tiny_cnn::image to cv::Mat and resize  
  176. cv::Mat image2mat(image<>& img) {  
  177.     cv::Mat ori(img.height(), img.width(), CV_8U, &img.at(0, 0));  
  178.     cv::Mat resized;  
  179.     cv::resize(ori, resized, cv::Size(), 3, 3, cv::INTER_AREA);  
  180.     return resized;  
  181. }  
  182.   
  183. void convert_image(const std::string& imagefilename,  
  184.     double minv,  
  185.     double maxv,  
  186.     int w,  
  187.     int h,  
  188.     vec_t& data) {  
  189.     auto img = cv::imread(imagefilename, cv::IMREAD_GRAYSCALE);  
  190.     if (img.data == nullptr) return// cannot open, or it's not an image  
  191.   
  192.     cv::Mat_<uint8_t> resized;  
  193.     cv::resize(img, resized, cv::Size(w, h));  
  194.   
  195.     // mnist dataset is "white on black", so negate required  
  196.     std::transform(resized.begin(), resized.end(), std::back_inserter(data),  
  197.         [=](uint8_t c) { return (255 - c) * (maxv - minv) / 255.0 + minv; });  
  198. }  

5.      编译时会提示几个错误,解决方法是:

(1)、error C4996,解决方法:将宏_SCL_SECURE_NO_WARNINGS添加到属性的预处理器定义中;

(2)、调用for_函数时,error C2668,对重载函数的调用不明教,解决方法:将for_中的第三个参数强制转化为size_t类型;

6.      运行程序,train时,运行结果如下图所示:


7.      对生成的model进行测试,通过画图工具,每个数字生成一张图像,共10幅,如下图:


通过导入train时生成的model,对这10张图像进行识别,识别结果如下图,其中6和9被误识为5和1:


GitHub:https://github.com/fengbingchun/NN


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值