android数字识别代码,Android集成TensorFlow使用Mnist数据集实现手写数字识别

概述

最想想学习一点Ai相关的东西,所有就简单实现了一个手写数字识别的项目,虽然其中很多的东西不是太明白,还需要自己不断的探索,这里就把目前的所学先记录下来。

Android端的实现

1、集成TensorFlow

网上很多集成TensorFlow的方法很复杂,需要编译源码,其实没有那么复杂,当然你也可以安装哪些步骤把源码下载下来进行编译集成,我是使用简单的集成方式,在Android工程下引入就行,代码如下:

implementation 'org.tensorflow:tensorflow-android:+'

// implementation 'org.tensorflow:tensorflow-android:1.13.1'

2、导入跨平台的模型pb文件

这里涉及模型的训练,这个相对来说还是比较复杂的,涉及到文件大小的优化和识别精准度的问题,我目前也训练出了几个模型但是精准度还是没有达到我的预期,但是刚开始学习还是勉强够用了。

7970d03c92f3

image.png

3、实现手写数字控件

这个就是自定义一个控件,在画布上书写数字,再拿到带有数字的bitmap对象。

package com.stormdzh.tfmnist.handwrite;

import android.content.Context;

import android.content.res.Resources;

import android.graphics.Bitmap;

import android.graphics.Canvas;

import android.graphics.Paint;

import android.graphics.Path;

import android.util.AttributeSet;

import android.view.MotionEvent;

import android.view.View;

import com.stormdzh.tfmnist.R;

/**

* @Description: 自定义的view实现手写数字

* @Author: dzh

* @CreateDate: 2020-05-15 22:43

*/

public class MyPaintView extends View {

private Resources myResources;

// 画笔,定义绘制属性

private Paint myPaint;

private Paint mBitmapPaint;

// 绘制路径

private Path myPath;

// 画布及其底层位图

private Bitmap myBitmap;

private Canvas myCanvas;

private float mX, mY;

private static final float TOUCH_TOLERANCE = 4;

// 记录宽度和高度

private int mWidth;

private int mHeight;

public MyPaintView(Context context) {

super(context);

initialize();

}

public MyPaintView(Context context, AttributeSet attrs, int defStyle) {

super(context, attrs, defStyle);

initialize();

}

public MyPaintView(Context context, AttributeSet attrs) {

super(context, attrs);

initialize();

}

/**

* 初始化工作

*/

private void initialize() {

myResources = getResources();

// 绘制自由曲线用的画笔

myPaint = new Paint();

myPaint.setAntiAlias(true);

myPaint.setDither(true);

myPaint.setColor(myResources.getColor(R.color.white));

myPaint.setStyle(Paint.Style.STROKE);

myPaint.setStrokeJoin(Paint.Join.ROUND);

myPaint.setStrokeCap(Paint.Cap.ROUND);

myPaint.setStrokeWidth(88);

myPath = new Path();

mBitmapPaint = new Paint(Paint.DITHER_FLAG);

}

@Override

protected void onSizeChanged(int w, int h, int oldw, int oldh) {

super.onSizeChanged(w, h, oldw, oldh);

mWidth = w;

mHeight = h;

myBitmap = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888);

myBitmap.eraseColor(myResources.getColor(R.color.purple_dark));

myCanvas = new Canvas(myBitmap);

}

@Override

public boolean onTouchEvent(MotionEvent event) {

float x = event.getX();

float y = event.getY();

switch (event.getAction()) {

case MotionEvent.ACTION_DOWN:

touch_start(x, y);

invalidate();

break;

case MotionEvent.ACTION_MOVE:

touch_move(x, y);

invalidate();

break;

case MotionEvent.ACTION_UP:

touch_up();

invalidate();

break;

}

return true;

}

@Override

protected void onDraw(Canvas canvas) {

super.onDraw(canvas);

// 如果不调用这个方法,绘制结束后画布将清空

canvas.drawBitmap(myBitmap, 0, 0, mBitmapPaint);

// 绘制路径

canvas.drawPath(myPath, myPaint);

}

private void touch_start(float x, float y) {

myPath.reset();

myPath.moveTo(x, y);

mX = x;

mY = y;

}

private void touch_move(float x, float y) {

float dx = Math.abs(x - mX);

float dy = Math.abs(y - mY);

if (dx >= TOUCH_TOLERANCE || dy >= TOUCH_TOLERANCE) {

myPath.quadTo(mX, mY, (x + mX) / 2, (y + mY) / 2);

mX = x;

mY = y;

}

}

private void touch_up() {

myPath.lineTo(mX, mY);

// commit the path to our offscreen

// 如果少了这一句,笔触抬起时myPath重置,那么绘制的线将消失

myCanvas.drawPath(myPath, myPaint);

// kill this so we don't double draw

myPath.reset();

}

/**

* 清除整个图像

*/

public void clear() {

// 清除方法1:重新生成位图

// myBitmap = Bitmap

// .createBitmap(mWidth, mHeight, Bitmap.Config.ARGB_8888);

// myCanvas = new Canvas(myBitmap);

// 清除方法2:将位图清除为白色

myBitmap.eraseColor(myResources.getColor(R.color.purple_dark));

// 两种清除方法都必须加上后面这两步:

// 路径重置

myPath.reset();

// 刷新绘制

invalidate();

}

public Bitmap getBitMap() {

return myBitmap;

}

}

4、实现布局的编写

android:layout_width="match_parent"

android:layout_height="match_parent"

android:background="#80300900"

android:gravity="center_horizontal"

android:orientation="vertical"

android:paddingLeft="16dp"

android:paddingTop="16dp"

android:paddingRight="16dp"

android:paddingBottom="16dp">

android:layout_width="wrap_content"

android:layout_height="wrap_content"

android:layout_gravity="center"

android:text="点击下面按钮可以实现测试不同数字" />

android:id="@+id/btnTest"

android:layout_width="match_parent"

android:layout_height="wrap_content"

android:text="测试" />

android:id="@+id/imgPrevieww"

android:layout_width="wrap_content"

android:layout_height="wrap_content"

android:layout_gravity="center" />

android:id="@+id/tvResult"

android:layout_width="match_parent"

android:layout_height="wrap_content"

android:gravity="center"

android:text="未知" />

android:id="@+id/mMyPaintView"

android:layout_width="320dp"

android:layout_height="320dp"

android:layout_marginTop="10dp"

android:background="#000000" />

android:layout_marginTop="10dp"

android:gravity="center_horizontal"

android:layout_width="match_parent"

android:layout_height="40dp"

android:orientation="horizontal">

android:id="@+id/btnClear"

android:layout_width="120dp"

android:layout_height="40dp"

android:text="清空" />

android:id="@+id/btnOk"

android:layout_width="120dp"

android:layout_height="40dp"

android:text="识别" />

5、实现一个预测的工具类,调用加载模型和实现预测基本方法

package com.stormdzh.tfmnist;

import android.content.res.AssetManager;

import android.graphics.Bitmap;

import android.graphics.Color;

import android.graphics.Matrix;

import android.util.Log;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

public class PredictionTF {

private static final String TAG = "PredictionTF";

//设置模型输入/输出节点的数据维度

private static final int IN_COL = 1;

private static final int IN_ROW = 28 * 28;

private static final int OUT_COL = 1;

private static final int OUT_ROW = 1;

//模型中输入变量的名称

// private static final String inputName = "x_input";

// private static final String inputName = "regression/Placeholder";

private static String inputName = "convolutional/x";

//模型中输出变量的名称

private static String outputName = "output";

private TensorFlowInferenceInterface inferenceInterface;

PredictionTF(AssetManager assetManager, String modePath) {

//初始化TensorFlowInferenceInterface对象

inferenceInterface = new TensorFlowInferenceInterface(assetManager, modePath);

Log.e(TAG, "模型文件加载成功");

}

/**

* 利用训练好的TensoFlow模型预测结果

*

* @param bitmap 输入被测试的bitmap图

* @return 返回预测结果,int数组

*/

public int[] getPredict(Bitmap bitmap) {

float[] inputdata = bitmapToFloatArray(bitmap, 28, 28);//需要将图片缩放带28*28

//将数据feed给tensorflow的输入节点

if(MainActivity.isRegression){

inputName="regression/Placeholder";

}else{

inputName = "convolutional/x";

}

inferenceInterface.feed(inputName, inputdata, IN_COL, IN_ROW);

if(!MainActivity.isRegression) {

float[] ss = new float[]{0.5f};

inferenceInterface.feed("convolutional/keep_prob", ss);

}

//运行tensorflow

String[] outputNames = new String[]{outputName};

inferenceInterface.run(outputNames);

///获取输出节点的输出信息

int[] outputs = new int[OUT_COL * OUT_ROW]; //用于存储模型的输出数据

inferenceInterface.fetch(outputName, outputs);

return outputs;

}

/**

* 将bitmap转为(按行优先)一个float数组,并且每个像素点都归一化到0~1之间。

*

* @param bitmap 输入被测试的bitmap图片

* @param rx 将图片缩放到指定的大小(列)->28

* @param ry 将图片缩放到指定的大小(行)->28

* @return 返回归一化后的一维float数组 ->28*28

*/

public static float[] bitmapToFloatArray(Bitmap bitmap, int rx, int ry) {

int height = bitmap.getHeight();

int width = bitmap.getWidth();

// 计算缩放比例

float scaleWidth = ((float) rx) / width;

float scaleHeight = ((float) ry) / height;

Matrix matrix = new Matrix();

matrix.postScale(scaleWidth, scaleHeight);

bitmap = Bitmap.createBitmap(bitmap, 0, 0, width, height, matrix, true);

Log.i(TAG, "bitmap width:" + bitmap.getWidth() + ",height:" + bitmap.getHeight());

Log.i(TAG, "bitmap.getConfig():" + bitmap.getConfig());

height = bitmap.getHeight();

width = bitmap.getWidth();

float[] result = new float[height * width];

int k = 0;

//行优先

for (int j = 0; j < height; j++) {

for (int i = 0; i < width; i++) {

int argb = bitmap.getPixel(i, j);

int r = Color.red(argb);

int g = Color.green(argb);

int b = Color.blue(argb);

int a = Color.alpha(argb);

//由于是灰度图,所以r,g,b分量是相等的。

assert (r == g && g == b);

// Log.i(TAG,i+","+j+" : argb = "+argb+", a="+a+", r="+r+", g="+g+", b="+b);

result[k++] = r / 255.0f;

}

}

return result;

}

}

6、在MainActivity中加载布局和调用预测工具类

package com.stormdzh.tfmnist;

import android.graphics.Bitmap;

import android.graphics.BitmapFactory;

import android.graphics.Matrix;

import android.os.Bundle;

import android.util.Log;

import android.view.View;

import android.widget.ImageView;

import android.widget.TextView;

import androidx.appcompat.app.AppCompatActivity;

import com.stormdzh.tfmnist.handwrite.MyPaintView;

public class MainActivity extends AppCompatActivity implements View.OnClickListener {

public static boolean isRegression = false; //true 使用线性模型

private static final String TAG = "MainActivity";

// private String MODEL_FILE = "file:///android_asset/mnist_dzh.pb"; //模型存放路径

// private String MODEL_FILE = "file:///android_asset/mnist_regression.pb"; //模型存放路径

private String MODEL_FILE = "file:///android_asset/mnist_convolutional.pb"; //模型存放路径

private TextView tvResult;

private ImageView imgPrevieww;

private Bitmap bitmap;

private PredictionTF preTF;

private int index = 0;

private MyPaintView mMyPaintView;

@Override

protected void onCreate(Bundle savedInstanceState) {

super.onCreate(savedInstanceState);

setContentView(R.layout.activity_main);

findViewById(R.id.btnTest).setOnClickListener(this);

findViewById(R.id.btnClear).setOnClickListener(this);

findViewById(R.id.btnOk).setOnClickListener(this);

tvResult = (TextView) findViewById(R.id.tvResult);

imgPrevieww = (ImageView) findViewById(R.id.imgPrevieww);

mMyPaintView = findViewById(R.id.mMyPaintView);

getBitmap();

if (isRegression) {

MODEL_FILE = "file:///android_asset/mnist_regression.pb"; //模型存放路径

} else {

MODEL_FILE = "file:///android_asset/mnist_convolutional.pb"; //模型存放路径

}

preTF = new PredictionTF(getAssets(), MODEL_FILE);//输入模型存放路径,并加载TensoFlow模型

}

private Bitmap getBitmap() {

switch (index) {

case 0:

bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.n0);

break;

case 1:

bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.n1);

break;

case 2:

bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.n2);

break;

case 3:

bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.n3);

break;

case 4:

bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.n4);

break;

case 5:

bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.n5);

break;

case 6:

bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.n6);

break;

case 7:

bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.n7);

break;

case 8:

bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.n8);

break;

case 9:

bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.n9);

break;

}

imgPrevieww.setImageBitmap(bitmap);

return bitmap;

}

@Override

public void onClick(View view) {

switch (view.getId()) {

case R.id.btnTest:

if (index > 9)

index = 0;

bitmap = getBitmap();

Log.i(TAG, "sourceBitmap=>" + bitmap.getWidth() + " :" + bitmap.getHeight());

index++;

recogBitmap(bitmap);

break;

case R.id.btnClear:

mMyPaintView.clear();

break;

case R.id.btnOk:

// Bitmap viewBitmap = convertViewToBitmap(mMyPaintView);

Bitmap viewBitmap = mMyPaintView.getBitMap();

// imgPrevieww.setImageBitmap(viewBitmap);

Bitmap finalBitmap = scaledBitmap(viewBitmap);

imgPrevieww.setImageBitmap(finalBitmap);

Log.i(TAG, "finalBitmap=>" + finalBitmap.getWidth() + " :" + finalBitmap.getHeight());

recogBitmap(finalBitmap);

break;

}

}

private Bitmap scaledBitmap(Bitmap viewBitmap) {

int width = viewBitmap.getWidth();

float scale = 74f / width;

Matrix matrix = new Matrix();

matrix.setScale(scale, scale);

return Bitmap.createBitmap(viewBitmap, 0, 0, viewBitmap.getWidth(),

viewBitmap.getHeight(), matrix, true);

}

private void recogBitmap(Bitmap bitmap) {

String res = "图片识别结果为:";

int[] result = preTF.getPredict(bitmap);

for (int i = 0; i < result.length; i++) {

Log.i(TAG, res + result[i]);

res = res + String.valueOf(result[i]) + " ";

}

tvResult.setText(res);

}

public Bitmap convertViewToBitmap(View view) {

view.measure(View.MeasureSpec.makeMeasureSpec(0, View.MeasureSpec.UNSPECIFIED), View.MeasureSpec.makeMeasureSpec(0, View.MeasureSpec.UNSPECIFIED));

view.layout(0, 0, view.getMeasuredWidth(), view.getMeasuredHeight());

view.buildDrawingCache();

Bitmap bitmap = view.getDrawingCache();

return bitmap;

}

}

效果

首先看下工程运行后的界面:

7970d03c92f3

image.png

点击测试按钮可以依次循环测试我添加的10中0-9的数字,这个写数字的识别率是100%。

黑色区域是手写区域,有清空和识别两个按钮,清空是清空画布,识别就是开始预测。

例如手写“4”的识别结果:

7970d03c92f3

image.png

目前demo中是使用卷积模型识别的,有些数字的写的歪了等异常情况是识别有误的,这个以后还需要继续优化。代码可以参考我github工程:TFMnist

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值