用TensorFlow Lite 写个手写体识别 APP


今天有个网友在手把手教你在Android上搭建tensorflow Lite2.0这篇文章下评论

求问如何进行一个图像的输入和数组的输出?

我想这也是很多初学者的痛点,很多入门同学都没有完整从模型建立,训练,到转换成TensorFlowLite,并在Android中实际的用。

于是我就把我之前写的demo给了他,想想还是抽空把这个demo写成文章,希望能够给帮助到更多的入门的同学。

虽然基于TensorFlow 实现手写体的文章,一抓一大把,但是我还是有必要啰嗦下,毕竟它是很好的入门人工智能的实例。

我不关注的手写体识别算法的细节,关注整个从模型到应用的整个过程,想对算法了解的,请自行学习。

有兴趣的同学可以关注下我的系列博客人工智能系列(更新中……),自己也在学习这方面的知识,一起学习和交流。

1 手写体基础知识

1.1 探索MINIST数据集

采用的MNIST数据集,它来自美国国家标准与技术研究所,National Institute of Standards and Technology(NIST)。 训练集 (training set) 由来自 250 个不同人手写的数字构成,其中 50% 是高中学生,50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set)也是同样比例的手写数字数据。

数据集中每张图片是什么样的呢?

就张这样子:
在这里插入图片描述
通过下面代码获得:

# Plot ad hoc mnist instances
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt

# load (downloaded if needed) the MNIST dataset
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# plot 4 images as gray scale
plt.subplot(221)
plt.imshow(X_train[0], cmap=plt.get_cmap("gray"))
plt.subplot(222)
plt.imshow(X_train[1], cmap=plt.get_cmap("gray"))
plt.subplot(223)
plt.imshow(X_train[2], cmap=plt.get_cmap("gray"))
plt.subplot(224)
plt.imshow(X_train[3], cmap=plt.get_cmap("gray"))
# show the plot
plt.show()

但是实际上存储是什么呢?
在这里插入图片描述
你可以发现这是一个0字,存储是0这张图片的RGB的值,凡是值为零的地方都是黑色,非零的地方都是不同灰阶。这就是一张图片灰阶RGB矩阵。

1.2 CNN基本介绍

本次采用手写体识别算法就是CNN(卷积神经网络),在计算机视觉中应用比较广泛。

最为经典的CNN手写体识别图,描述了手写体识别的整个过程,具体的细节就不讲了,有机会写一篇这个算法细节的文章,但是本文神经网络模型结构如下:
CNN

1.3 基于TensorFlow 的手写体识别

采用TensorFlow 中Keras接口,比较适合新手使用。让你感觉创建神经网络模型就像是搭积木一样。

代码如下,留意注释。

import numpy
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.python.keras.utils import np_utils
import tensorflow as tf
import pathlib

# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# reshape to be [samples][channels][width][height]
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1).astype('float32')

# normalize inputs from 0-255 to 0-1
X_train = X_train / 255
X_test = X_test / 255

print(X_train.shape)
# one hot encode outputs
y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
print(X_train[0])

num_classes = y_test.shape[1]

def baseline_model():
    # create model
    model = Sequential()
    model.add(Conv2D(32, kernel_size=(5, 5),
                     input_shape=(28, 28, 1),//采用单通道的图片
                     activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.2))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dense(num_classes, activation='softmax'))
    # Compile model
    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer='adam',
                  metrics=['accuracy'])
    return model


model = baseline_model()
# Fit the model
model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=10, batch_size=200, verbose=2)

# Final evaluation of the model
scores = model.evaluate(X_test, y_test, verbose=0)
print("CNN Error: %.2f%%" % (100 - scores[1] * 100))

# 上面升级网络训练的过程
# 下面需要将其转换tensorflow Lite模型,便于在Android中使用。
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

tflite_model_file = pathlib.Path('saved_model/model.tflite')
tflite_model_file.write_bytes(tflite_model)

2 在Android实现手写体识别

如果你不知道如何配置Android的环境,请参考手把手教你在Android上搭建tensorflow Lite2.0

2.1 加载模型

将训练好的TensorFlow Lite 文件放在Android的asset文件夹下。

public class TF {
    private static Context mContext;
    Interpreter mInterpreter;
    private static TF instance;

    public static TF newInstance(Context context) {
        mContext = context;
        if (instance == null) {
            instance = new TF();
        }
        return instance;
    }

    Interpreter get() {
        try {
            if (Objects.isNull(mInterpreter))
                mInterpreter = new Interpreter(loadModelFile(mContext));
        } catch (IOException e) {
            e.printStackTrace();
        }
        return mInterpreter;
    }

    // 获取文件
    private MappedByteBuffer loadModelFile(Context context) throws IOException {
        AssetFileDescriptor fileDescriptor = context.getAssets().openFd("model.tflite");
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }
}

2.2 自定义写画View

public class HandWriteView extends View {
    Path mPath = new Path();
    Paint mPaint;

    Bitmap mBitmap;
    Canvas mCanvas;

    public HandWriteView(Context context) {
        super(context);
        init();
    }

    public HandWriteView(Context context, AttributeSet attrs) {
        super(context, attrs);
        init();
    }

    void init() {
        mPaint = new Paint();
        mPaint.setColor(Color.WHITE);
        mPaint.setStyle(Paint.Style.STROKE);
        mPaint.setStrokeJoin(Paint.Join.ROUND);
        mPaint.setStrokeCap(Paint.Cap.ROUND);
        mPaint.setStrokeWidth(30);

    }

    @Override
    protected void onDraw(Canvas canvas) {
        super.onDraw(canvas);
        mBitmap = Bitmap.createBitmap(getWidth(), getHeight(), Bitmap.Config.ARGB_8888);
        mCanvas = new Canvas(mBitmap);
        mCanvas.drawColor(Color.BLACK);
        canvas.drawPath(mPath, mPaint);
        mCanvas.drawPath(mPath, mPaint);
    }

    @Override
    public boolean onTouchEvent(MotionEvent event) {
        switch (event.getAction()) {
            case MotionEvent.ACTION_DOWN:
                mPath.moveTo(event.getX(), event.getY());
                break;
            case MotionEvent.ACTION_MOVE:
                mPath.lineTo(event.getX(), event.getY());
                break;
            case MotionEvent.ACTION_UP:
            case MotionEvent.ACTION_CANCEL:
                break;
        }
        postInvalidate();
        return true;
    }

    Bitmap getBitmap() {
        mPath.reset();
        return mBitmap;
    }
}

2.3 将bitmap转成网络需要的格式

因为数据集中的数据都是28 * 28 * 3的,28为图片的宽和高,3为R,G,B三个通道,所以在输入到网络之前,我们需要将bitmap转成网络需要的格式。

private ByteBuffer convertBitmapToByteBuffer(Bitmap bitmap) {
        int inputShape[] = TF.newInstance(getApplicationContext()).get().getInputTensor(0).shape();
        int inputImageWidth = inputShape[1];
        int inputImageHeight = inputShape[2];
        Bitmap bs = Bitmap.createScaledBitmap(bitmap, inputImageWidth, inputImageHeight, true);
        mImageView.setImageBitmap(bs);
        ByteBuffer byteBuffer = ByteBuffer.allocateDirect(4 * inputImageHeight * inputImageWidth);
        byteBuffer.order(ByteOrder.nativeOrder());

        int[] pixels = new int[inputImageWidth * inputImageHeight];
        bs.getPixels(pixels, 0, bs.getWidth(), 0, 0, bs.getWidth(), bs.getHeight());

        for (int pixelValue : pixels) {
            int r = (pixelValue >> 16 & 0xFF);
            int g = (pixelValue >> 8 & 0xFF);
            int b = (pixelValue & 0xFF);

            // Convert RGB to grayscale and normalize pixel value to [0..1]
            float normalizedPixelValue = (r + g + b) / 3.0f / 255.0f;
            byteBuffer.putFloat(normalizedPixelValue);
        }
        return byteBuffer;
    }

2.4 识别结果的输出

识别的结果是根据0-9的概率进行判断,概率最大的就是识别的结果。

float[][] input = new float[1][10];
TF.newInstance(getApplicationContext()).get().run(convertBitmapToByteBuffer(mHandWriteView.getBitmap()), input);
int result = -1;
float value = 0f;
for (int j = 0; j < 10; j++) {
    if (input[0][j] > value) {
        value = input[0][j];
        result = j;
    }
Log.i("TAG", "result: " + j + " " + input[0][j]);
}
if (input[0][result] < 0.2f) {
    mTextView.setText("结果为:未识别");
} else {
    mTextView.setText("结果为:" + result);
}

识别结果:
在这里插入图片描述

若有需要,请自行点击demo下载。

3 总结

开发一个人工智能APP的主要流程就这么多,关键还是在于算法,要想得到更为精准的模型,除了要采用更好的模型之外,还需要对数据进行旋转,增强或者白质化,来提高数据的多样性。

欢迎大家一起交流!!!!

  • 8
    点赞
  • 44
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 22
    评论
手写体识别是一种常见的人工智能应用,它可以识别人们书的手字母或数字。而利用 TensorFlow Lite 技术开发手写体识别 app 可以实现更快速、更准确的识别体验。下面是关于如何用 TensorFlow Lite 一个手写体识别 app 的介绍。 首先,我们需要建立一个手数字识别模型,并将其转换为 TensorFlow Lite 模型。可以使用 TensorFlow 的 Keras 库来训练模型,使用 MNIST 数据集或其他手数字数据集。训练完成后,需要使用 TensorFlow Lite 转换器将模型转换为 TensorFlow Lite 模型,以便在移动设备上运行。 接下来,需要使用移动开发工具来开发手写体识别 app。Android 和 iOS 都支持 TensorFlow Lite,我们可以使用对应的 Android Studio 和 XCode 来进行应用开发。在开发过程中,需要将 TensorFlow Lite 模型嵌入到 app 中,并根据 app需要编相应的代码。 对于 Android 应用开发,可以使用 TensorFlow Lite Android 集成库来轻松地将 TensorFlow Lite 模型嵌入到应用中,并在需要识别数字时调用模型计算。同时,也可以使用 Android 自带的手输入功能来获取用户手输入的数字,在模型计算之前对输入进行处理和预处理。 对于 iOS 应用开发,可以使用 TensorFlow Lite iOS 集成库来导入 TensorFlow Lite 模型,并使用 Core ML 框架将其集成到应用中。此外,还可以使用 iOS 13 新增的手输入框架获取用户输入的手数字,并将其输入模型进行识别。 在完成开发和测试后,可以将 app 发布到相应的应用商店,供用户下载和使用。 总之,利用 TensorFlow Lite 技术开发手写体识别 app 是一个非常实用的人工智能应用开发项目,可以为用户提供便捷、准确的手数字识别功能。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 22
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

技术人Howzit

钱不钱的无所谓,这是一种鼓励!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值