基于tensorflow实现Android手写数字识别

前段时间训练了mnist手写数字识别的模型,学习后将其移植到Android端
我是参考的大佬https://puke3615.github.io/2017/08/02/Run-Mnist-On-Android/https://github.com/wangtianrui/TFonAndroid的源码,有需要的的朋友可以去下载,这里是对他写的代码的分析和我自己的理解
注解ButterKnife学习:https://www.jianshu.com/p/952c6f5e8157

implementation 'com.jakewharton:butterknife:8.8.1'

手机上效果为:
在这里插入图片描述
在这里插入图片描述

移植到Android时要添加依赖文件:libandroid_tensorflow_inference_java.jar,和编译后的TensoFlow的so库,libtensorflow_inference.so,将其添加在lib文件夹中:
在这里插入图片描述

接下来将训练好的pb模型放入assets文件夹中
在这里插入图片描述

在build.gradle文件中添加:这个可以支持在手机中调试

testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"

这里分享一下运行程序时遇到的坑:
出现了问题“Android-Device supports x86,but APK only supports armeabi-v7a,armeabi,x86_64”,使用模拟器不能运行,因为之前添加了支持tensorflow的so库和jar包,后来我在build文件中添加

   multiDexEnabled true
        ndk {
            abiFilters "armeabi-v7a"
        }

在这里插入图片描述
也还是没有用,后来看到大佬的代码仿佛才明白了一些东西。。。
在build中增加:

    sourceSets {
        main {
            jniLibs.srcDirs = ['libs']
        }
    }

下面是实现过程:

定义mnist分类器:
private final String MODEL_PATH = "file:///android_asset/mnist.pb";//加载模型
    public static final String INPUT_NAME = "input";//对应训练模型占位符x-input
    public static final String KEEP_PROB_NAME = "keep_prob";
    public static final String OUTPUT_NAME = "output";//训练模型时的占位符y_
//注意训练模型时一定要对值和标签设置name值,导入Android后要喂数据
//tensorflow依赖文件的类
    private TensorFlowInferenceInterface inference;

//图片像素28*28
    private final int width = 28;
    private final int heifht = 28;
    private float[] inputs = new float[width * heifht];
    private int[] INPUT_SHAPE = new int[]{1, width * heifht};

//AssetManager :提供低级别的访问应用资源的API
//不同模型框架,训练模型输入的占位符不同,一定要一一对应;
//训练的数据集一定要对齐,resize与采用模型框架图像的大小一致,Android端调用接口,输入参数一定要和训练图像一致,否则会出现分类错误。
    public MnistClassifier(AssetManager assetManager) {
        this.inference = new TensorFlowInferenceInterface(assetManager, MODEL_PATH);//传入模型的路径
        //模型使用阶段, 不需要进行dropout处理, 所以keep_prob直接为1.0
        //dropout层:keep_prob训练时为0.5,测试时为1
        inference.feed(KEEP_PROB_NAME, new float[]{1.0f}, 1);
    }

    public float[] getResult(float[] inputs) {
        try {
            this.inputs = inputs;
        } catch (Exception e) {
            e.printStackTrace();
        }
        //输出结果是十个数字的概率
        float[] output = new float[10];
//填入Input数据
        inference.feed(INPUT_NAME, inputs, 1, width * heifht);
        //运行结果, 类似Python中的sess.run([outputs])
        inference.run(new String[]{OUTPUT_NAME}, false);
        inference.fetch(OUTPUT_NAME, output);
        return output;
    }
定义画板:
public class PrinterView extends View {

    //画笔
    private Paint paint;

    //用来存储“路径”
    private Path path;

    //屏幕宽
    private int width;

    public PrinterView(Context context) {
        super(context);
    }

    public PrinterView(Context context, @Nullable AttributeSet attrs) {
        super(context, attrs);
        setBackgroundColor(Color.WHITE);
        paint = new Paint();
        paint.setColor(Color.RED);
        paint.setStrokeWidth(TypedValue.applyDimension(TypedValue.COMPLEX_UNIT_DIP, 20, getResources().getDisplayMetrics()));
        paint.setStyle(Paint.Style.STROKE);
        path = new Path();

        int screenWidth = getResources().getDisplayMetrics().widthPixels;
        width = MeasureSpec.makeMeasureSpec(screenWidth, MeasureSpec.EXACTLY);
    }
    /**
    1.精确模式(MeasureSpec.EXACTLY)

在这种模式下,尺寸的值是多少,那么这个组件的长或宽就是多少。

2.最大模式(MeasureSpec.AT_MOST)

这个也就是父组件,能够给出的最大的空间,当前组件的长或宽最大只能为这么大,当然也可以比这个小。

3.未指定模式(MeasureSpec.UNSPECIFIED)

这个就是说,当前组件,可以随便用空间,不受限制。
    */

//画出手指滑动的轨迹
    @Override
    public boolean onTouchEvent(MotionEvent event) {
        float x = event.getX();
        float y = event.getY();
        switch (event.getAction()) {
            case MotionEvent.ACTION_DOWN:
                //按下
                path.moveTo(x, y);
                break;
            case MotionEvent.ACTION_MOVE:
                path.lineTo(x, y);
                break;
        }
        //刷新view
        invalidate();
        return true;
    }
    @Override
    protected void onDraw(Canvas canvas) {
        super.onDraw(canvas);
        canvas.drawPath(path, paint);
    }

    @Override
    protected void onMeasure(int widthMeasureSpec, int heightMeasureSpec) {
        //定制画板的宽和高
        super.onMeasure(width, width);
    }
    public void clean() {
        path.reset();
        invalidate();
    }
    public boolean isEmpty() {
        return path.isEmpty();
    }
//向外部提供读取画布数据的方法
/**
View组件显示的内容可以通过cache机制保存为bitmap, 使用到的api有
    void  setDrawingCacheEnabled(boolean flag),
    Bitmap  getDrawingCache(boolean autoScale),
    void  buildDrawingCache(boolean autoScale),
    void  destroyDrawingCache()
    我们要获取它的cache先要通过setDrawingCacheEnable方法把cache开启,然后再调用getDrawingCache方法就可 以获得view的cache图片了。buildDrawingCache方法可以不用调用,因为调用getDrawingCache方法时,若果 cache没有建立,系统会自动调用buildDrawingCache方法生成cache。若果要更新cache, 必须要调用destoryDrawingCache方法把旧的cache销毁,才能建立新的。
当调用setDrawingCacheEnabled方法设置为false, 系统也会自动把原来的cache销毁。
   ViewGroup在绘制子view时,而外提供了两个方法
   void setChildrenDrawingCacheEnabled(boolean enabled)
   setChildrenDrawnWithCacheEnabled(boolean enabled)
   setChildrenDrawingCacheEnabled方法可以使viewgroup里所有的子view开启cache, setChildrenDrawnWithCacheEnabled使在绘制子view时,若该子view开启了cache, 则使用它的cache进行绘制,从而节省绘制时间。
   获取cache通常会占用一定的内存,所以通常不需要的时候有必要对其进行清理,通过destroyDrawingCache或setDrawingCacheEnabled(false)实现。
*/
    public float[] getData(int width, int height) {
        float[] data = new float[height * width];
        try {
            //先让cache可以被读取(将View转化为图片都会使用cache)
            setDrawingCacheEnabled(true);
            setDrawingCacheQuality(View.DRAWING_CACHE_QUALITY_LOW);
            Bitmap cache = getDrawingCache();
            dealData(cache, data, width, height);
        } finally {
            setDrawingCacheEnabled(false);
        }
        return data;
    }

    private void dealData(Bitmap bm, float[] data, int newWidth, int newHeight) {
        //获得bitmap的宽和高
        int width = bm.getWidth();
        int height = bm.getHeight();

        //计算缩放比例
        float scaleWidth = ((float) newWidth) / width;
        float scaleHeight = ((float) newHeight) / height;
//取得想要缩放的matrix参数
        Matrix matrix = new Matrix();
        matrix.postScale(scaleWidth, scaleHeight);
        //获得目标大小的图
        Bitmap newBm = Bitmap.createBitmap(bm, 0, 0, width, height, matrix, true);

        for (int y = 0; y < newHeight; y++) {
            for (int x = 0; x < newWidth; x++) {
                //获得每个点的像素值
                int pixel = newBm.getPixel(x, y);
                data[newWidth * y + x] = pixel == 0xffffffff ? 0 : 1;
            }
        }

    }
}

识别逻辑:

    @OnClick({R.id.printer_view, R.id.result_text_view, R.id.clean_button, R.id.detect_button})
    public void onViewClicked(View view) {
        switch (view.getId()) {
            case R.id.clean_button:
                printerView.clean();
                resultTextView.setText(null);
                break;
            case R.id.detect_button:
                if (printerView.isEmpty()) {
                    resultTextView.setText("画板为空");
                    break;
                }
                MnistClassifier mnistClassifier = new MnistClassifier(getAssets());
                float[] result = mnistClassifier.getResult(printerView.getData(28, 28));
                List<MnistItem> items = new ArrayList<>(10);
                for (int i = 0; i < result.length; i++) {
                    items.add(new MnistItem(result[i], i));
                }
                Collections.sort(items);//选择概率最大的对应的值
                StringBuilder builder = new StringBuilder();
                for (int i = 0; i < 1 ; i++) {
                    MnistItem item = items.get(i);
                    builder.append((int)item.getIndex());
                }
                resultTextView.setText(builder.toString());
                break;
        }
  • 3
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 8
    评论
### 回答1: 基于TensorFlow的MNIST手写数字识别是一种机器学习技术,它可以通过训练模型来识别手写数字。MNIST是一个常用的数据集,包含了大量的手写数字图像和对应的标签。TensorFlow是一个流行的深度学习框架,可以用来构建和训练神经网络模型。通过使用TensorFlow,我们可以构建一个卷积神经网络模型,对MNIST数据集进行训练和测试,从而实现手写数字识别的功能。 ### 回答2: 随着机器学习技术的不断发展,MNIST手写数字识别已成为一个基础、常见的图像分类问题。TensorFlow是目前最流行的深度学习框架之一,广泛应用于图像处理、自然语言处理等领域,所以在TensorFlow实现MNIST手写数字识别任务是非常具有代表性的。 MNIST手写数字识别是指从给定的手写数字图像中识别出数字的任务。MNIST数据集是一个由数万张手写数字图片和相应标签组成的数据集,图片都是28*28像素的灰度图像。每一张图片对应着一个标签,表示图片中所代表的数字。通过对已经标记好的图片和标签进行训练,我们将构建一个模型来预测测试集中未知图片的标签。 在TensorFlow实现MNIST手写数字识别任务,可以通过以下步骤完成: 1. 导入MNIST数据集:TensorFlow中的tf.keras.datasets模块内置了MNIST数据集,可以通过如下代码导入:(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data() 2. 数据预处理:对数据进行标准化处理,即将灰度值范围从[0,255]缩放到[0,1]之间。同时将标签值进行独热编码,将每个数字的标签由一个整数转换为一个稀疏向量。采用以下代码完成数据预处理:train_images = train_images / 255.0 test_images = test_images / 255.0 train_labels = tf.keras.utils.to_categorical(train_labels, 10) test_labels = tf.keras.utils.to_categorical(test_labels, 10) 3. 构建模型:采用卷积神经网络(CNN)进行建模,包括卷积层、池化层、Dropout层和全连接层。建议采用可重复使用的模型方法tf.keras.Sequential()。具体代码实现为:model = tf.keras.Sequential([ tf.keras.layers.Conv2D(32, (3,3), activation='relu',input_shape=(28,28,1)), tf.keras.layers.MaxPooling2D((2,2)), tf.keras.layers.Flatten(), tf.keras.layers.Dropout(0.5)), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(10, activation='softmax') ]) 4. 编译模型:指定优化器、损失函数和评估指标。可采用Adam优化器,交叉熵损失函数和准确率评估指标。具体实现代码如下:model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy']) 5. 训练模型:采用train()函数进行模型训练,完成代码如下:model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels)) 6. 评估模型:计算测试准确率,完成代码如下:test_loss, test_acc = model.evaluate(test_images, test_labels) print('Test accuracy:', test_acc) 以上就是基于TensorFlow的MNIST手写数字识别的简要实现过程。其实实现过程还可以更加复杂,比如调节神经元数量,添加卷积层数量等。总之采用TensorFlow框架实现MNIST手写数字识别是一个可行的任务,未来机器学习发展趋势将越来越向深度学习方向前进。 ### 回答3: MNIST手写数字识别是计算机视觉领域中最基础的问题,使用TensorFlow实现这一问题可以帮助深入理解神经网络的原理和实现,并为其他计算机视觉任务打下基础。 首先,MNIST手写数字数据集由28x28像素的灰度图像组成,包含了数字0到9共10个类别。通过导入TensorFlow及相关库,我们可以很容易地加载MNIST数据集并可视化: ``` import tensorflow as tf import matplotlib.pyplot as plt (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data() print("Training images:", train_images.shape) print("Training labels:", train_labels.shape) print("Test images:", test_images.shape) print("Test labels:", test_labels.shape) plt.imshow(train_images[0]) plt.show() ``` 在实现MNIST手写数字识别的神经网络模型中,最常用的是卷积神经网络(Convolutional Neural Networks,CNN),主要由卷积层、激活层、池化层和全连接层等组成。卷积层主要用于提取局部特征,激活层用于引入非线性性质,池化层则用于加速处理并减少过拟合,全连接层则进行最终的分类。 以下为使用TensorFlow搭建CNN实现MNIST手写数字识别的代码: ``` model = tf.keras.Sequential([ tf.keras.layers.Conv2D(32, kernel_size=(3,3), activation='relu', input_shape=(28,28,1)), tf.keras.layers.MaxPooling2D(pool_size=(2,2)), tf.keras.layers.Conv2D(64, kernel_size=(3,3), activation='relu'), tf.keras.layers.MaxPooling2D(pool_size=(2,2)), tf.keras.layers.Flatten(), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dropout(0.5), tf.keras.layers.Dense(10, activation='softmax') ]) model.summary() model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) train_images = train_images.reshape((60000, 28, 28, 1)) train_images = train_images / 255.0 test_images = test_images.reshape((10000, 28, 28, 1)) test_images = test_images / 255.0 model.fit(train_images, train_labels, epochs=5, batch_size=64) test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2) print("Test accuracy:", test_acc) ``` 这段代码中使用了两个卷积层分别提取32和64个特征,池化层进行特征加速和降维,全连接层作为最终分类器输出预测结果。在模型训练时,使用Adam优化器和交叉熵损失函数进行训练,经过5个epoch后可以得到约99%的测试准确率。 总之,通过使用TensorFlow实现MNIST手写数字识别的经历,可以深切认识到深度学习在计算机视觉领域中的应用,以及如何通过搭建和训练神经网络模型来解决实际问题。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值