Handwritten Digit Recognition App with TensorFlow Lite: A Design Experiment

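Implementation process

First, train a handwritten digit recognition neural network in PyCharm using the TensorFlow framework and save it in SavedModel format (a .pb file). Then convert that model to .tflite format, bundle it with the app, and build the app's functionality on top of it.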

Versions used

python 3
tensorflow 2.3.0
Android Studio 4

Training the handwritten digit network on the MNIST dataset

import tensorflow as tf
from tensorflow import keras

num_epochs = 10
batch_size = 50
learning_rate = 0.001

mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

train_images = train_images / 255.0
test_images = test_images / 255.0

model = keras.Sequential([
    keras.layers.Flatten(),
    keras.layers.Dense(100, activation=tf.nn.relu),
    keras.layers.Dense(10),
    keras.layers.Softmax()
])

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
    loss=keras.losses.sparse_categorical_crossentropy,
    metrics=[keras.metrics.sparse_categorical_accuracy]
)
model.fit(train_images, train_labels, epochs=num_epochs, batch_size=batch_size)

test_loss, test_acc = model.evaluate(test_images, test_labels)
print('Test loss:', test_loss)
print('Test accuracy:', test_acc)

# Export in SavedModel format: the saved/1 directory will contain saved_model.pb
model.save('saved/1')

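Saving as a .tflite file

The reference examples I had seen generally run the conversion from the command line (cmd), but I found that cumbersome when I tried it, so I do the conversion directly in PyCharm instead.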

import tensorflow as tf

saved_model_path = r'C:\Users\49103\PycharmProjects\deepL1\saved\1'

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)

# converter=tf.lite.TFLiteConverter.from_saved_model(saved_model_dir=in_path,input_arrays=[input_tensor_name],output_arrays=[class_tensor_name])
tflite_model = converter.convert()

with open('model.tflite', 'wb') as f:
    f.write(tflite_model)
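
Before moving to Android, it can be worth checking the converted file on the desktop. The following is a minimal sketch, not part of the original project: run_tflite and top_prediction are hypothetical helper names, and it assumes the model takes a single 28x28 image per call. NumPy and TensorFlow are imported inside the function so that the small helper can be used on its own.

```python
def top_prediction(probabilities):
    """Return (digit, probability) for the highest-scoring class."""
    probs = list(probabilities)
    digit = max(range(len(probs)), key=probs.__getitem__)
    return digit, probs[digit]


def run_tflite(model_path, image):
    """Run one image (pixel values already scaled to [0, 1]) through a .tflite model."""
    # Imported lazily: only needed when a model is actually run.
    import numpy as np
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    input_info = interpreter.get_input_details()[0]
    output_info = interpreter.get_output_details()[0]

    # Cast and reshape to whatever the converted model reports as its input,
    # assumed here to be a batch holding a single 28x28 image.
    batch = np.asarray(image, dtype=input_info["dtype"]).reshape(input_info["shape"])
    interpreter.set_tensor(input_info["index"], batch)
    interpreter.invoke()
    scores = interpreter.get_tensor(output_info["index"])
    return top_prediction(scores.reshape(-1).tolist())
```

For example, calling run_tflite('model.tflite', test_images[0]) with test_images normalized as in the training script should return the predicted digit and its probability, which can then be compared with test_labels[0].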

App project structure

[Figure: app project structure]
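I extracted the images from the MNIST dataset and picked one as the test picture. It sits next to the tflite model in the assets resource directory and is used to test the app's functionality.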
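
One possible way to produce such a test picture is sketched below. This is an assumption about the workflow, not the method actually used here: gray_png_bytes and export_mnist_test_image are hypothetical helpers, and the PNG is written with the standard library only so that no imaging package is required.

```python
import struct
import zlib


def _chunk(kind, data):
    """Build one PNG chunk: length, type, data, CRC32 over type + data."""
    crc = zlib.crc32(kind + data) & 0xFFFFFFFF
    return struct.pack(">I", len(data)) + kind + data + struct.pack(">I", crc)


def gray_png_bytes(rows):
    """Encode rows of 0-255 grayscale values as an 8-bit grayscale PNG."""
    height, width = len(rows), len(rows[0])
    # Each scanline is prefixed with filter type 0 (no filtering).
    raw = b"".join(b"\x00" + bytes(int(v) for v in row) for row in rows)
    # IHDR: width, height, bit depth 8, color type 0 (grayscale), no interlace.
    header = struct.pack(">IIBBBBB", width, height, 8, 0, 0, 0, 0)
    return (b"\x89PNG\r\n\x1a\n" + _chunk(b"IHDR", header)
            + _chunk(b"IDAT", zlib.compress(raw)) + _chunk(b"IEND", b""))


def export_mnist_test_image(index, path):
    """Save one MNIST test image as a PNG file and return its label."""
    # Imported lazily: only needed to download and load the dataset.
    from tensorflow import keras

    (_, _), (test_images, test_labels) = keras.datasets.mnist.load_data()
    with open(path, "wb") as f:
        f.write(gray_png_bytes(test_images[index].tolist()))
    return int(test_labels[index])
```

Calling export_mnist_test_image with an index of your choice and a file name such as 'mnist_test_3.png' writes the picture and returns the true label, which is useful later for checking what the app predicts.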

Android Studio code

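// Imports assume an AndroidX project with the org.tensorflow:tensorflow-lite dependency.
import android.Manifest;
import android.app.Activity;
import android.content.pm.PackageManager;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Build;
import android.os.Bundle;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;
import android.widget.Toast;

import androidx.annotation.NonNull;
import androidx.appcompat.app.AppCompatActivity;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;

import org.tensorflow.lite.Interpreter;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;

public class MainActivity extends AppCompatActivity {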
    private String mModelName="model";
    private Button add_image,load_model;
    private TextView result_text;
    private ImageView show_image;
    private boolean load_result=false;
    private Interpreter tflite=null;
    private int[] ddims={1,1,28,28};
    private static String mTestImage="mnist_test_3.png";


    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        checkPermission();
        init();

    }

    private void init() {
        add_image=findViewById(R.id.add_image);
        load_model=findViewById(R.id.load_model);
        result_text=findViewById(R.id.result_text);
        show_image=findViewById(R.id.show_image);

        load_model.setOnClickListener((view)-> {
            try {
                tflite=new Interpreter(loadModelFile(MainActivity.this));
                Toast.makeText(MainActivity.this,"load model success",Toast.LENGTH_SHORT).show();
                tflite.setNumThreads(1);
                load_result=true;
            }catch(IOException e){
                Toast.makeText(MainActivity.this,"load model false",Toast.LENGTH_SHORT).show();
                load_result=false;
                e.printStackTrace();
            }
        });
        AssetManager assetManager=this.getAssets();
        add_image.setOnClickListener((view)->{
            if(!load_result){
                Toast.makeText(MainActivity.this,"never load model",Toast.LENGTH_SHORT).show();
                return;
            }
            try {
                InputStream inputStream=assetManager.open(mTestImage);
                Bitmap bitmap=BitmapFactory.decodeStream(inputStream);
                show_image.setImageBitmap(bitmap);
            } catch (IOException e) {
                e.printStackTrace();
            }
            predict_image(mTestImage);

        });

    }


    /** Memory-map the model file in Assets. */
    private MappedByteBuffer loadModelFile(Activity activity) throws IOException {
        AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(mModelName+".tflite");
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }


    private void predict_image(String image_path) {
        // picture to float array
        Bitmap bmp = getScaleBitmap(image_path);
        ByteBuffer inputData = getScaledMatrix(bmp, ddims);
        try {
            // Data format conversion takes too long
            // Log.d("inputData", Arrays.toString(inputData));
            float[][] labelProbArray = new float[1][10];
            long start = System.currentTimeMillis();
            // get predict result
            tflite.run(inputData, labelProbArray);
            long end = System.currentTimeMillis();
            long time = end - start;
            float[] results = new float[labelProbArray[0].length];
            // System.arraycopy copies a range of elements from one array into another
            System.arraycopy(labelProbArray[0], 0, results, 0, labelProbArray[0].length);
            // show predict result and time
            int r = get_max_result(results);
            String show_text = "result:" + r +  "\nprobability:" + results[r] + "\ntime:" + time + "ms";
            result_text.setText(show_text);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public static ByteBuffer getScaledMatrix(Bitmap bitmap, int[] ddims) {
        // Allocate a direct byte buffer: 4 bytes (one float) per input element.
        ByteBuffer imgData = ByteBuffer.allocateDirect(ddims[0] * ddims[1] * ddims[2] * ddims[3] * 4);
        // Use the platform's native byte order for the buffer.
        imgData.order(ByteOrder.nativeOrder());
        // get image pixel
        int[] pixels = new int[ddims[2] * ddims[3]];
        // Create a new bitmap scaled from the existing one to the model's input size.
        Bitmap bm = Bitmap.createScaledBitmap(bitmap, ddims[2], ddims[3], false);
        bm.getPixels(pixels, 0, ddims[2], 0, 0, ddims[2], ddims[3]);
        int pixel = 0;
        for (int i = 0; i < ddims[2]; ++i) {
            for (int j = 0; j < ddims[3]; ++j) {
                final int val = pixels[pixel++];
                imgData.putFloat((((val & 0xFF) - 0) / 255.0f));
            }
        }

        // Release the scaled bitmap once its pixels have been copied out.
        if (!bm.isRecycled()) {
            bm.recycle();
        }
        return imgData;
    }


    private int get_max_result(float[] result) {
        float probability = result[0];
        int r = 0;
        for (int i = 0; i < result.length; i++) {
            if (probability < result[i]) {
                probability = result[i];
                r = i;
            }
        }
        return r;
    }

    private void checkPermission(){
        if(Build.VERSION.SDK_INT>=Build.VERSION_CODES.M){
            String[] permissions=new String[]{
              Manifest.permission.CAMERA,
              Manifest.permission.READ_EXTERNAL_STORAGE,
              Manifest.permission.WRITE_EXTERNAL_STORAGE
            };
            for (String permission:permissions){
                if (ContextCompat.checkSelfPermission(this,permission)!= PackageManager.PERMISSION_GRANTED){
                    // Request the whole set once, then stop checking.
                    ActivityCompat.requestPermissions(this,permissions,1);
                    break;
                }
            }
        }
    }

    @Override
    public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
        super.onRequestPermissionsResult(requestCode, permissions, grantResults);
        switch (requestCode) {
            case 1:
                if (grantResults.length > 0) {
                    for (int i = 0; i < grantResults.length; i++) {

                        int grantResult = grantResults[i];
                        if (grantResult == PackageManager.PERMISSION_DENIED) {
                            String s = permissions[i];
                            Toast.makeText(this, s + " permission was denied", Toast.LENGTH_SHORT).show();
                        }
                    }
                }
                break;
        }
    }

    public Bitmap getScaleBitmap(String testImage) {
        // BitmapFactory offers four kinds of static methods for loading a Bitmap:
        // decodeFile, decodeResource, decodeStream and decodeByteArray, which read from
        // a local image file, a project resource, a stream, or a byte array respectively.
        // Here the test image is read from assets as a stream.
        AssetManager assetManager=this.getAssets();
        Bitmap bitmap=null;
        try (InputStream in=assetManager.open(testImage)) {
            bitmap=BitmapFactory.decodeStream(in);
        } catch (IOException e) {
            e.printStackTrace();
        }
        return bitmap;
    }
}
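
The preprocessing in getScaledMatrix can be mirrored in a few lines of Python, which helps when checking that the app feeds the model the same values the training script used. This is a rough sketch using only the standard library; pack_pixels is a hypothetical name, and it assumes grayscale pixels, so the low byte of each ARGB int carries the gray value.

```python
import struct


def pack_pixels(argb_pixels):
    """Low byte of each ARGB int -> float32 in [0, 1], packed in native byte order."""
    # Masking with 0xFF also works for negative ints, as Java's signed ints would be.
    values = [(p & 0xFF) / 255.0 for p in argb_pixels]
    # "=" selects native byte order, comparable to ByteOrder.nativeOrder() in the app.
    return struct.pack("=%df" % len(values), *values)
```

A 28x28 image yields 784 floats, so the buffer is 3136 bytes long, the same size the app allocates with allocateDirect.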

Results

[Figures: screenshots of the app running]

Reference: tensorflow lite
