手写数字识别--Android Studio 加载tensorflow模型

版权声明:本文为博主原创文章,遵循 CC 4.0 by-sa 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_38956849/article/details/85041381

思路:

    在电脑端调用mnist数据集,构建深度卷积神经网络模型,使用TensorFlow进行训练,达到99%+的测试集数据准确率,继而把模型迁移到App端。具体迁移教程参考之前的文章:

https://mp.csdn.net/postedit/85016120这是模型迁移的教程

https://mp.csdn.net/postedit/85009068这是训练神经网络的代码

下面详细讲解本项目的具体实施方式:

    搭建环境就不多说了,之前的文章讲解的很清楚。

    训练模型的过程也不多说了,之前的文章有详细的代码可以参考。

一、界面设计

<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:orientation="vertical"
    android:paddingBottom="16dp"
    android:paddingLeft="16dp"
    android:paddingRight="16dp"
    android:paddingTop="16dp">

    <TextView
        android:layout_weight="1"
        android:id="@+id/txt_id"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:gravity="center"
        android:textSize="12pt"/>

    <ImageView
        android:layout_weight="6"
        android:id="@+id/imageView1"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_gravity="center"/>

    <LinearLayout
        android:layout_width="match_parent"
        android:layout_height="wrap_content">
        <Button
            android:id="@+id/btn_mnist"
            android:layout_weight="1"
            android:layout_width="match_parent"
            android:layout_height="wrap_content"
            android:text="识别" />
        <Button
            android:id="@+id/btn_photo"
            android:layout_weight="1"
            android:layout_width="match_parent"
            android:layout_height="wrap_content"
            android:text="拍照" />
    </LinearLayout>

</LinearLayout>

二、活动MnistActivity代码

package com.example.tan.tfmodel;

import android.content.Intent;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.provider.MediaStore;
import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;
import android.widget.Toast;

public class ActMnist extends AppCompatActivity {
    TextView txt;
    Bitmap bitmap;
    ImageView imageView;

    //模型存放路径
    PredictionTF preTF;
    private static final String MODEL_FILE = "file:///android_asset/mnist.pb";

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_act_mnist);

        //输入模型存放路径,并加载TensoFlow模型
        preTF =new PredictionTF(getAssets(),MODEL_FILE);

        txt=(TextView)findViewById(R.id.txt_id);
        imageView =(ImageView)findViewById(R.id.imageView1);
        bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.test_image);
        imageView.setImageBitmap(bitmap);

        Button btn = findViewById(R.id.btn_mnist);
        btn.setOnClickListener(new View.OnClickListener(){
            @Override
            public void onClick(View view) {
                int result= preTF.getPredict(bitmap);
                String ret = "预测结果为:"+String.valueOf(result);
                txt.setText(ret);
                Toast.makeText(getApplicationContext(), ret, Toast.LENGTH_SHORT).show();
            }
        });

        btn = findViewById(R.id.btn_photo);
        btn.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                try{
                    Intent it = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
                    startActivityForResult(it,11);
                }catch (Exception e){
                    Toast.makeText(getApplicationContext(), "get_photo 出错", Toast.LENGTH_SHORT).show();
                }
            }
        });
    }

    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        super.onActivityResult(requestCode, resultCode, data);
//        Toast.makeText(this, "onActivityResult:"+String.valueOf(resultCode), Toast.LENGTH_SHORT).show();
        try{
            bitmap = (Bitmap)data.getExtras().get("data");
            imageView.setImageBitmap(bitmap);
            int result= preTF.getPredict(bitmap);
            String ret = "预测结果为:"+String.valueOf(result);
            txt.setText(ret);
            Toast.makeText(this, ret, Toast.LENGTH_SHORT).show();
        }catch(Exception e){
            Toast.makeText(this, e.toString(), Toast.LENGTH_SHORT).show();
        }
    }
}



三、最重要的就是封装成类的模型代码

package com.example.tan.tfmodel;

import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.Color;
import android.graphics.Matrix;
import android.util.Log;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;


public class PredictionTF {
    private static final String TAG = "PredictionTF";
    //设置模型输入/输出节点的数据维度
    private static final int IN_COL = 1;
    private static final int IN_ROW = 28*28;
    private static final int OUT_COL = 1;
    private static final int OUT_ROW = 1;
    //模型中输入变量的名称
    private static final String inputName = "input/x_input";
    //模型中输出变量的名称
    private static final String outputName = "output";

    TensorFlowInferenceInterface tfInfer;
    static {//加载libtensorflow_inference.so库文件
        System.loadLibrary("tensorflow_inference");
        Log.e(TAG,"libtensorflow_inference.so库加载成功");
    }

    PredictionTF(AssetManager assetManager, String modePath) {
        //初始化TensorFlowInferenceInterface对象
        tfInfer = new TensorFlowInferenceInterface(assetManager,modePath);
        Log.e(TAG,"TensoFlow模型文件加载成功");
    }

    /**
     *  利用训练好的TensoFlow模型预测结果
     * @param bitmap 输入被测试的bitmap图
     * @return 返回预测结果,int
     */
    public int getPredict(Bitmap bitmap) {
        float[] inputdata = bitmapToFloatArray(bitmap,28,28);//需要将图片缩放带28*28
        //将数据feed给tensorflow的输入节点
        tfInfer.feed(inputName, inputdata, IN_COL, IN_ROW);

        //运行tensorflow
        String[] outputNames = new String[] {outputName};
        tfInfer.run(outputNames);

        ///获取输出节点的输出信息
        int[] outputs = new int[OUT_COL*OUT_ROW]; //用于存储模型的输出数据
        tfInfer.fetch(outputName, outputs);
        return outputs[0];
    }

    /**
     * 将bitmap转为(按行优先)一个float数组,并且每个像素点都归一化到0~1之间。
     * @param bitmap 输入被测试的bitmap图片
     * @param rx 将图片缩放到指定的大小(列)->28
     * @param ry 将图片缩放到指定的大小(行)->28
     * @return   返回归一化后的一维float数组 ->28*28
     */
    public static float[] bitmapToFloatArray(Bitmap bitmap, int rx, int ry){
        int height = bitmap.getHeight();
        int width = bitmap.getWidth();
        // 计算缩放比例
        float scaleWidth = ((float) rx) / width;
        float scaleHeight = ((float) ry) / height;
        Matrix matrix = new Matrix();
        matrix.postScale(scaleWidth, scaleHeight);
        bitmap = Bitmap.createBitmap(bitmap, 0, 0, width, height, matrix, true);
        Log.i(TAG,"bitmap width:"+bitmap.getWidth()+",height:"+bitmap.getHeight());
        Log.i(TAG,"bitmap.getConfig():"+bitmap.getConfig());
        height = bitmap.getHeight();
        width = bitmap.getWidth();
        float[] result = new float[height*width];
        int k = 0;
        //行优先
        for(int j = 0;j < height;j++){
            for (int i = 0;i < width;i++){
                int argb = bitmap.getPixel(i,j);
                int r = Color.red(argb);
                int g = Color.green(argb);
                int b = Color.blue(argb);
                int a = Color.alpha(argb);
                //由于是灰度图,所以r,g,b分量是相等的。
                assert(r==g && g==b);
//                Log.i(TAG,i+","+j+" : argb = "+argb+", a="+a+", r="+r+", g="+g+", b="+b);
                result[k++] = r / 255.0f;
            }
        }
        return result;
    }
}

 

可以看出,识别虽然错误,却能够正常运行,希望接下来可以把模型改进一下。

展开阅读全文

没有更多推荐了,返回首页