C++实现CNN识别手写数字

去年(2017年)参加robomaster时做了一段视觉,为了打大符,其实就是识别手写数字,然后控制云台射击击打指定数字。因为时间有限,而且其他部分代码都是用的C++和opencv写的,所以识别手写数字这部分代码也用C++写了。不过注意,我只写了前向计算的代码,训练的代码我没写,网络是在matlab上训练了,然后自己定义了一种存储格式存在了xml文件中,然后视觉部分的程序就是读取xml文件导入CN...
摘要由CSDN通过智能技术生成

更新:

原文已经搬运至网站:https://www.link2sea.com/archives/383,后续也将在该网站进行更新。

查看博主更多文章请前往:https://www.link2sea.com/

下面是原文

原文:

去年(2017年)参加robomaster时做了一段视觉,为了打大符,其实就是识别手写数字,然后控制云台射击击打指定数字。因为时间有限,而且其他部分代码都是用的C++和opencv写的,所以识别手写数字这部分代码也用C++写了。不过注意,我只写了前向计算的代码,训练的代码我没写,网络是在matlab上训练了,然后自己定义了一种存储格式存在了xml文件中,然后视觉部分的程序就是读取xml文件导入CNN,然后只做前向计算去做分类,正确率好像97%吧,其实做的很low,凑活用吧,偶然翻到了去年写的代码,写程序加调试,反正从开始着手做视觉到做完花了1个多星期吧,最后我做的视觉并没有用上,很遗憾。现在都用tensorflow了,这些代码也用不上,写篇博客就算是纪念下吧,毕竟一个多星期的心血。

下面就包括三个文件的程序,分别是 cnn.cpp,cnn.hpp 和 CNN_para.xml 三个文件。下面分别是三个文件的代码:

首先是cnn.cpp

/* SUST 陶亚凡 */

#include "cnn.hpp"
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/calib3d/calib3d.hpp>
#include <opencv2/features2d/features2d.hpp>

#include <iostream>

using namespace cv;
using namespace std;


/* Forward pass over the loaded network for one input map `x`.
 * Layer 0 is treated as the input layer (single map); 'c' layers are
 * convolutional, 's' layers are mean-pool/subsample layers, and the final
 * feature maps are flattened and pushed through a fully-connected layer
 * (ffW, ffb) with a sigmoid to produce the output vector `net.o`.
 * Returns all rows of net.o except row 0.
 * NOTE(review): dropping row 0 presumably discards an unused class slot
 * (e.g. "digit 0") from the MATLAB-trained net — confirm against the
 * training code / xml export. */
Mat CNN::cnnff(Mat x){
    cnn_t & net = cnn;
    int n = net.layers.size();
	net.layers[0].a[0] = x;
	int inputmaps = 1;            // number of feature maps feeding the current layer
    char str[30] = {0};
	
	Mat z;                        // accumulated convolution response for one output map
	Mat temp;
    Mat show;

	for(int l=1; l<n; l++){
        if(net.layers[l].type == 'c'){
            for(int j=0; j<net.layers[l].outputmaps; j++){
                // Half-width of the kernel: filter2D produces a "same"-size
                // result, so the valid region shrinks by `bias` on each side.
                int bias = (net.layers[l].kernelsize - 1)/2;
                Size insize;
                insize.height = net.layers[l - 1].a[0].rows - 2*bias;
                insize.width = net.layers[l - 1].a[0].cols - 2*bias;
                // Sum the responses of output map j over all input maps i.
                filter2D(net.layers[l - 1].a[0], z, net.layers[l - 1].a[0].depth(), net.layers[l].k[j][0]);

                for(int i=1; i<inputmaps; i++){
                    filter2D(net.layers[l - 1].a[i], temp, net.layers[l - 1].a[i].depth(), net.layers[l].k[j][i]);
					z = z + temp;
                }
                // Crop the border pixels so only the valid-convolution region
                // remains (emulating MATLAB conv2(...,'valid')).
                z(Rect(bias, bias, insize.width, insize.height)).copyTo(z);
                // Add the per-map scalar bias and apply the logistic sigmoid.
                net.layers[l].a[j] = sigm(z + net.layers[l].b[j]);
            } // for(int j=0; j<net.layers[l].outputmaps)
				
			inputmaps = net.layers[l].outputmaps;
        } // if(net.layers[l].type == 'c')
        else if(net.layers[l].type == 's'){

            // Mean-pooling: convolve with a scale x scale averaging kernel,
            // then take every second pixel via nearest-neighbour resize.
            Mat one_mat = cv::Mat::ones(net.layers[l].scale, net.layers[l].scale, CV_32F);
            one_mat = one_mat / net.layers[l].scale / net.layers[l].scale;
            for(int i=0; i<inputmaps; i++){
                filter2D(net.layers[l - 1].a[i], temp, net.layers[l - 1].a[i].depth(), one_mat);
                // NOTE(review): Mat::zeros takes (rows, cols) and Size takes
                // (width, height); the arguments below look transposed. This
                // is harmless only while the feature maps are square — verify.
                net.layers[l].a[i] = Mat::zeros(temp.cols, temp.rows, CV_32F);
                // Shift by one pixel before downsampling — presumably to line
                // the sampling grid up with MATLAB's pooling convention; TODO
                // confirm against the training-side implementation.
                temp(Rect(1, 1, temp.cols-1, temp.rows-1)).copyTo(net.layers[l].a[i](Rect(0, 0, temp.cols-1, temp.rows-1)));
                resize(net.layers[l].a[i], net.layers[l].a[i], Size(temp.rows/2, temp.cols/2), 0, 0, INTER_NEAREST);
            }
		}
    }

	// Flatten the last layer's feature maps into one column vector `fv`
	// (each map contributes height*width consecutive rows).
	Size sa = net.layers[n-1].a[0].size();
    int a_size = net.layers[n-1].a.size();
    Mat fv_tmp = Mat::zeros(a_size * sa.height * sa.width, 1, CV_32F);
    for(int j=0; j<a_size; j++){
        temp = MatToColumnVector(net.layers[n-1].a[j]);
        temp.copyTo(fv_tmp.rowRange(j * sa.height * sa.width, (j+1) * sa.height * sa.width));
	}
	fv_tmp.copyTo(net.fv);
	// Fully-connected output layer: o = sigmoid(W * fv + b).
	net.o = sigm(net.ffW * net.fv + net.ffb);
//cout << net.o << endl;
//    Point min_idx;
//    Point max_idx;
//    double min;
//    double max;
//    minMaxLoc(net.o, &min, &max, &min_idx, &max_idx);

//    return max_idx.y;
    return net.o.rowRange(1, net.o.rows);
}

/* Flatten InputMat into a (rows*cols) x 1 column vector.
 * The input is transposed first, so the vector holds the input's elements
 * in column-major order (matching MATLAB's (:) linearization).
 *
 * Bug fix: the previous version wrapped `temp.data` in a new Mat header
 * that did not share ownership of the buffer; when the local `temp` was
 * destroyed, the returned vector pointed at freed memory (use-after-free).
 * reshape() only re-labels the header, and clone() gives the result its
 * own, independently owned buffer. */
Mat CNN::MatToColumnVector(const cv::Mat &InputMat){
	int Rows = InputMat.rows * InputMat.cols;
    Mat transposed = InputMat.t();          // evaluates to a continuous Mat
    return transposed.reshape(0, Rows).clone();
}

/* Populate `cnn` from an OpenCV FileStorage holding the MATLAB-exported
 * network. The xml uses a flat, ad-hoc key scheme:
 *   layers                      - total layer count
 *   cnn_ffW / cnn_ffb           - fully-connected weights / biases
 *   cnn_layers<l>_type          - layer type as an int ('c', 's', or other)
 *   cnn_layers<l>_outputmaps,
 *   cnn_layers<l>_kernelsize,
 *   cnn_layers<l>_k<i><j>,
 *   cnn_layers<l>_b<j>          - per-'c'-layer parameters
 *   cnn_layers<l>_scale         - per-'s'-layer pool size */
void CNN::loadcnn(const FileStorage& fs){
    int layers = 0;
    char str[40] = {0};
    int inputmaps = 1;            // map count entering the current layer

    fs["layers"] >> layers;
    fs["cnn_ffW"] >> cnn.ffW;
    fs["cnn_ffb"] >> cnn.ffb;

    for(int l=0; l<layers; l++){
        one_layer_t new_layer;
        int type = 0;
        sprintf(str, "cnn_layers%d_type", l);
        fs[str] >> type;
        new_layer.type = (char)type;      // stored numerically in the xml


        Mat a;                            // empty placeholder feature map
        if(new_layer.type == 'c'){

            sprintf(str, "cnn_layers%d_outputmaps", l);
            fs[str] >> new_layer.outputmaps;

            sprintf(str, "cnn_layers%d_kernelsize", l);
            fs[str] >> new_layer.kernelsize;

            /* read k for the c layer */
            for(int j=0; j<new_layer.outputmaps; j++){
                vector<Mat> k_j;

                for(int i=0; i<inputmaps; i++){
                    Mat k_ji;
                    // NOTE(review): the key is formatted with (l, i, j) while
                    // the kernel is stored as k[j][i]; presumably this mirrors
                    // how the MATLAB exporter named the keys — verify against
                    // the exporter before reusing this format.
                    sprintf(str, "cnn_layers%d_k%d%d", l, i, j);
                    fs[str] >> k_ji;
                    k_j.push_back(k_ji);
                }
                new_layer.k.push_back(k_j);
                // Reserve one (empty) activation map per output map; cnnff
                // fills these during the forward pass.
                new_layer.a.push_back(a);

                float b;
                sprintf(str, "cnn_layers%d_b%d", l, j);
                fs[str] >> b;
                new_layer.b.push_back(b);
            }

            inputmaps = new_layer.outputmaps;
        }
        else if(new_layer.type == 's'){

            sprintf(str, "cnn_layers%d_scale", l);
            fs[str] >> new_layer.scale;

            // Pooling keeps the map count; reserve one placeholder per map.
            for(int i=0; i<inputmaps; i++){
                new_layer.a.push_back(a);
            }
        }
        else{
            // Input (or unknown) layer: a single placeholder map.
            new_layer.a.push_back(a);
        }

        cnn.layers.push_back(new_layer);
    }

/* for test */
#ifndef CNN_TEST
//    #define  CNN_TEST
#endif

// Optional self-test: reads the 28x28 sample image from the xml, shows it,
// and times one (currently commented-out) forward pass.
#ifdef CNN_TEST
    Mat x;
    Mat show;
    fs["test_data"] >> x;

    show = x * 255;
    show.convertTo(show, CV_8U);
    transpose(show,show);
    imshow("input pic", show);

    int t1,t2;
    t1 = cv::getTickCount();
    int ans = 0;
//    ans = cnnff(x);
    cout << "ans = " << ans << endl;
    t2 = cv::getTickCount();
    cout << "Consumer-Time: " << (t2 - t1) * 1000.0 / cv::getTickFrequency() << "ms" << endl;

    waitKey(0);
#endif
}


/* Element-wise logistic sigmoid: 1 / (1 + e^-P).
 *
 * Bug fix: cv::Mat is a shallow-copying handle, so even though P is taken
 * "by value" the old `exp(-P, P)` wrote through to the caller's buffer,
 * silently mutating the argument. Computing into a fresh Mat keeps the
 * function side-effect free (the signature is unchanged). */
Mat CNN::sigm(Mat P){
    Mat e;
    exp(-P, e);           // e = exp(-P), element-wise
    return 1/(1 + e);
}

接着是cnn.hpp:

/* SUST 陶亚凡 */
#pragma once
#include "opencv2/core/core.hpp"
#include "opencv2/highgui/highgui.hpp"

#include <vector>
#include <utility>
using namespace cv;

// Minimal forward-only CNN: the network is trained in MATLAB, exported to an
// xml FileStorage, loaded by loadcnn(), and evaluated by cnnff().
class CNN {
public:
	
	// One network layer. NOTE(review): the struct tag is misspelled
	// ("one_layar"); only the typedef name one_layer_t is used elsewhere.
	typedef struct one_layar {
        char type;                          // 'c' = conv, 's' = subsample, other = input
        std::vector<cv::Mat> a;             // activation maps (filled by cnnff)
        std::vector<float> b;               // per-output-map biases ('c' layers)
		int outputmaps;                     // number of output maps ('c' layers)
		int kernelsize;                     // square kernel side length ('c' layers)
        std::vector<std::vector<cv::Mat>> k;// kernels, indexed k[outputmap][inputmap]
		int scale;                          // pooling window side ('s' layers)
	}one_layer_t;

	// Whole network: layer stack plus the fully-connected output stage.
	typedef struct cnn{
        std::vector<one_layer_t> layers;
        cv::Mat fv;                         // flattened final feature vector
        cv::Mat ffW;                        // fully-connected weights
        cv::Mat ffb;                        // fully-connected biases
        cv::Mat o;                          // network output (post-sigmoid)
	}cnn_t;

public:
    // NOTE(review): declared but no definition is visible in this file —
    // using it will fail at link time unless defined elsewhere.
    CNN();

    // Load the network from an xml/yml FileStorage file on construction.
    CNN(const std::string & filename){
        FileStorage setting_fs(filename, FileStorage::READ);
        loadcnn(setting_fs);
        setting_fs.release();
	}

    // Returns a copy of the full network state (Mat headers are shared).
    cnn_t cnnreturn(){return cnn;}
    // Forward pass; see cnn.cpp.
    cv::Mat cnnff(cv::Mat x);

protected:
    void loadcnn(const cv::FileStorage& fs);
    cv::Mat MatToColumnVector(const cv::Mat &InputMat);
    cv::Mat sigm(cv::Mat P);
private:
    cnn_t cnn;
};

最后是存储CNN结构的xml文件的数据,其中还包括一个例图28*28的:

<?xml version="1.0" encoding="utf-8"?>
<opencv_storage>
   <layers>5</layers>
   <cnn_ffW type_id="opencv-matrix">
      <rows>10</rows>
      <cols>192</cols>
      <dt>f</dt>
      <data>-0.568684 -0.840148 -0.454545 -1.318691 -0.542536 0.262243 -0.190369 0.188239 0.168000 0.150902 -0.776808 0.276366 0.932940 -0.352636 0.049857 -0.016501 -1.307408 0.418510 1.703576 -0.275972 0.327144 -0.138172 -0.276332 -0.283987 0.166491 -0.427241 -0.047831 -0.288850 1.384224 0.820337 -0.152349 -0.619389 -0.179057 0.390203 0.915956 0.663474 -0.632854 -1.023481 -0.569145 -0.551961 0.050373 -0.443177 -1.381474 -1.862783 1.142882 1.096406 0.173264 -1.261844 -0.127105 0.177692 0.322486 -1.689607 -1.103839 -0.474490 -0.396990 -0.919875 -1.812920 0.095724 -0.548418 -0.263181 -1.279738 0.837480 2.166042 2.395184 0.039804 -0.517603 0.383755 0.961707 0.718807 -0.990300 0.508205 0.767228 -0.249060 0.790682 0.780152 -0.349029 0.356309 0.664598 0.612107 -0.952957 -1.162191 -0.413933 -0.007629 -1.119251 -1.785648 -0.975712 -0.808158 -0.878558 -1.913150 -0.696603 -1.048091 -0.554665 -0.998564 -0.749730 0.560802 1.071287 -1.263524 -1.357791 -0.382083 -0.438742 -0.147290 -0.472937 -1.144568 1.020860 -0.386209 -0.904939 1.490753 1.583735 0.351348 -1.353244 -0.017483 1.558953 0.123887 0.572187 -0.513296 -1.274787 1.160012 0.325008 -0.453308 -1.265027 1.114152 -0.706548 -0.857605 -0.830754 0.010065 -0.484612 0.113302 0.800446 0.011397 0.450172 -0.345578 -1.867464 0.983451 -0.304089 -0.519521 -1.146774 -0.291938 -0.025131 -1.005695 -0.154603 -0.961208 -0.318354 0.332191 0.746977 0.017100 0.971496 0.350664 -0.247105 0.442959 -0.954869 -0.197612 -0.769494 -1.617853 -0.873141 -0.186129 0.185366 -1.805443 -0.410434 0.504179 -0.039253 0.344962 0.670727 0.746010 0.827250 -0.366512 -1.378792 -1.921945 -0.581498 -0.683050 -1.075891 -0.485163 -0.795280 -0.115390 -0.277790 1.091993 1.596517 -0.288662 -0.174590 0.381413 -0.776837 -0.774400 0.082224 -1.847365 -1.367115 -1.071762 -1.154800 -0.430710 0.187430 -0.981025 0.173648 0.619034 0.395178 0.670848 -0.190562 1.578060 1.081296 -0.568814 0.058157 0.686897 -0.635822 0.900271 0.439748 0.697416 -0.387801 -0.518810 0.290915 0.350912 
-0.715028 -1.647098 -2.401533 -2.408404 -0.715723 -0.580959 0.128428 -0.738845 -0.865536 -0.780717 -0.422889 -1.387967 0.936520 0.654036 -0.850674 0.481915 -0.885330 0.151825 -0.003216 0.113889 0.895744 -0.334037 -1.075032 0.259564 -0.049494 -0.467224 -2.171169 0.044801 0.352682 0.169473 0.207688 1.986277 1.080322 0.031921 -0.064943 0.291739 -0.384199 0.443816 -0.678896 -0.149150 -0.103646 0.606457 -0.851627 0.666114 0.470407 0.326996 0.171576 0.866813 -1.323896 -0.896076 2.009523 1.870687 -0.088985 0.150642 0.732542 -0.415214 0.096963 -1.214229 -1.938633 0.411621 -1.070952 0.057377 -2.841435 -0.297376 0.690975 -0.864447 -0.928141 -0.510805 -0.883791 -0.459018 -0.859915 -1.351250 -1.154162 -0.298049 -2.317568 -0.768941 -0.665157 -1.651614 -0.661028 -0.192331 -0.302086 -0.063164 0.661084 0.988643 0.628787 0.061343 -1.064286 -0.399700 0.003194 -0.586815 0.845286 -0.440110 0.007938 -0.716887 1.713239 -0.470617 -0.644067 -1.051352 1.370634 0.377730 1.402900 0.152722 0.469785 -0.354833 -0.164888 0.177156 1.002260 -0.598730 -0.687314 0.602623 0.339472 0.684160 0.383501 0.670998 -0.209076 0.135982 0.173958 0.076611 1.078072 -0.646387 0.351450 0.496793 0.754224 0.023728 0.604938 1.085749 0.952455 0.258260 -0.911355 -0.461454 -0.284038 1.916496 1.390354 0.950783 1.991596 -0.027214 0.753137 0.845360 0.499624 0.201375 1.444512 0.830598 0.161795 0.984455 0.709851 0.102322 0.832743 -0.137935 -2.126368 -1.093201 0.316729 -0.345943 -0.667056 -0.668533 0.954747 -0.592794 -0.622690 1.710437 1.086969 -0.121517 -1.767624 -0.193522 -1.277201 -2.141669 -0.783888 0.709102 -0.584764 -1.178761 -1.046521 -0.500537 -1.407760 -1.456903 -1.166710 0.114529 -1.222344 -0.485937 0.366948 0.678354 0.881663 -1.885476 -1.764005 -0.574188 0.999854 -0.375238 -1.371269 0.284578 -1.017808 0.163411 0.149505 1.415138 1.614423 2.244284 1.782228 -1.215948 0.959955 -0.236855 -2.091445 -0.425496 0.373386 -1.719937 -1.513815 -0.480159 0.293449 -1.128259 -0.109912 -1.335769 -0.414058 2.265155 3.415650 -0.462177 
1.326707 0.814925 0.393905 0.579491 0.613182 -0.108630 -0.614807 -1.861581 -1.632298 0.396614 -0.919935 -1.208531 0.535476 1.987387 1.474315 0.702240 0.918989 -0.236851 -1.880555 1.675656 0.526113 -0.359203 1.523022 1.762864 0.422372 0.828221 0.258063 0.032520 2.631290 0.444756 -0.210713 -1.560551 -0.438164 -0.160917 -1.351426 -0.961034 -1.900673 -0.934203 -0.572908 0.472002 0.265572 0.399855 -0.221858 0.620770 2.063923 1.105910 0.108550 0.081312 0.092552 -0.589873 -1.328607 2.411575 0.019260 -1.319895 -0.496505 -1.256033 -1.149169 -0.300088 0.803819 -1.205085 0.152804 -0.190873 1.999035 0.262617 -1.265177 -1.926570 -0.850561 -0.071587 -0.322788 -0.481424 -0.532940 -0.287171 -1.247944 1.320629 0.277887 0.361195 -1.152884 -2.887138 -1.550886 -0.365399 0.083547 -0.735211 -2.018912 0.055647 -1.070096 -0.807210 -0.741435 1.793538 0.538075 -0.096737 1.292373 1.647268 0.869284 0.188351 0.191427 0.180810 0.663445 -0.998189 -1.756560 0.431852 1.402694 1.247088 0.235649 0.025021 -1.903438 -0.063016 0.435442 -1.837954 -0.968299 -1.118530 -0.990334 0.519120 -0.820571 -0.407835 -0.238509 0.314300 1.376042 -0.030988 -0.271011 0.902175 0.178550 -0.537286 -0.722177 -2.392766 -0.244671 0.928415 0.731318 -1.323398 -1.728493 0.730880 -0.846633 1.481762 1.500963 2.334629 0.828116 0.365176 -1.337337 1.031938 0.956734 0.798765 -0.938663 0.000748 -0.134341 0.025908 0.644007 0.996969 -0.548298 2.046839 0.998907 -0.701453 -0.971052 1.675005 0.871147 -0.244738 -1.120503 -0.323273 0.596829 0.669091 0.954019 -0.860371 -1.386022 -0.392221 0.243332 -0.241974 -0.683338 -0.147636 -0.069763 0.539730 -0.502065 0.660576 0.272680 -0.973756 -2.302124 -1.806688 -1.397432 1.978154 -0.381901 -3.991838 -3.267560 -0.813706 -0.035162 0.358718 0.067427 -1.571896 0.745802 0.883409 -0.832795 -1.015582 -0.339173 -1.485606 -1.993328 0.446727 0.261968 -0.165605 0.148824 -2.105331 -2.298861 0.111934 -0.327436 -0.391297 -0.006948 -0.276133 0.201642 0.957007 0.301777 -0.292803 -1.836007 1
  • 6
    点赞
  • 46
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值