实现过程
首先在PyCharm中利用TensorFlow框架训练手写数字识别神经网络模型,并将其保存为SavedModel格式(.pb文件),再将该模型转换为.tflite格式,作为app的模型,基于该模型实现app的功能。
各版本说明
python 3
tensorflow 2.3.0
Android studio 4
基于mnist数据集的手写字神经网络训练
import tensorflow as tf
from tensorflow import keras
# Training hyperparameters.
num_epochs = 10
batch_size = 50
learning_rate = 0.001

# Load MNIST and scale the pixel values by 1/255 before feeding them to the network.
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images = train_images / 255.0
test_images = test_images / 255.0

# Flatten -> 100-unit ReLU hidden layer -> 10 logits -> softmax probabilities.
# The input shape is declared explicitly so the exported model has a fixed
# 28x28 per-sample input.
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(100, activation=tf.nn.relu),
    keras.layers.Dense(10),
    keras.layers.Softmax()
])

model.compile(
    # Use the hyperparameter declared above instead of repeating the literal.
    optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
    loss=keras.losses.sparse_categorical_crossentropy,
    metrics=[keras.metrics.sparse_categorical_accuracy]
)

model.fit(train_images, train_labels, epochs=num_epochs, batch_size=batch_size)

# Report loss and accuracy on the held-out test split.
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(test_loss)
print('\nTest accuracy:', test_acc)

# Export to the 'saved/1' directory (relative to the working directory).
model.save('saved/1')
保存为.tflite文件
我之前看到的参考案例一般在cmd中实现文件转换,但是我在尝试过程中感觉比较麻烦,所以就直接在pycharm中进行转换
import tensorflow as tf
# Directory written by the training script's model.save('saved/1').
# A relative path keeps the script runnable on machines other than the author's;
# run it from the same working directory as the training script.
saved_model_path = 'saved/1'

# Convert the SavedModel into a TensorFlow Lite flatbuffer.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)
tflite_model = converter.convert()

# convert() returns the serialized model as bytes, so write it in binary mode.
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)
app项目结构
从mnist数据集中解压出图片并选取其中一张作为测试图片,与tflite模型一起放在assets资源目录下,用于测试app的功能。
Android studio程序
public class MainActivity extends AppCompatActivity {
private String mModelName="model";
private Button add_image,load_model;
private TextView result_text;
private ImageView show_image;
private boolean load_result=false;
private Interpreter tflite=null;
private int[] ddims={1,1,28,28};
private static String mTestImage="mnist_test_3.png";
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
checkPermission();
init();
}
private void init() {
add_image=findViewById(R.id.add_image);
load_model=findViewById(R.id.load_model);
result_text=findViewById(R.id.result_text);
show_image=findViewById(R.id.show_image);
load_model.setOnClickListener((view)-> {
try {
tflite=new Interpreter(loadModelFile(MainActivity.this));
Toast.makeText(MainActivity.this,"load model success",Toast.LENGTH_SHORT).show();
tflite.setNumThreads(1);
load_result=true;
}catch(IOException e){
Toast.makeText(MainActivity.this,"load model false",Toast.LENGTH_SHORT).show();
load_result=false;
e.printStackTrace();
}
});
AssetManager assetManager=this.getAssets();
add_image.setOnClickListener((view)->{
if(!load_result){
Toast.makeText(MainActivity.this,"never load model",Toast.LENGTH_SHORT).show();
return;
}
try {
InputStream inputStream=assetManager.open(mTestImage);
Bitmap bitmap=BitmapFactory.decodeStream(inputStream);
show_image.setImageBitmap(bitmap);
} catch (IOException e) {
e.printStackTrace();
}
predict_image(mTestImage);
});
}
/** Memory-map the model file in Assets. */
private MappedByteBuffer loadModelFile(Activity activity) throws IOException {
AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(mModelName+".tflite");
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
private void predict_image(String image_path) {
// picture to float array
Bitmap bmp = getScaleBitmap(image_path);
ByteBuffer inputData = getScaledMatrix(bmp, ddims);
try {
// Data format conversion takes too long
// Log.d("inputData", Arrays.toString(inputData));
float[][] labelProbArray = new float[1][10];
long start = System.currentTimeMillis();
// get predict result
tflite.run(inputData, labelProbArray);
long end = System.currentTimeMillis();
long time = end - start;
float[] results = new float[labelProbArray[0].length];
//System.arraycopy把一个数组中某一段字节数据放到另一个数组中
System.arraycopy(labelProbArray[0], 0, results, 0, labelProbArray[0].length);
// show predict result and time
int r = get_max_result(results);
String show_text = "result:" + r + "\nprobability:" + results[r] + "\ntime:" + time + "ms";
result_text.setText(show_text);
} catch (Exception e) {
e.printStackTrace();
}
}
public static ByteBuffer getScaledMatrix(Bitmap bitmap, int[] ddims) {
//基于新分配的内存块创建直接字节缓冲区。
ByteBuffer imgData = ByteBuffer.allocateDirect(ddims[0] * ddims[1] * ddims[2] * ddims[3] * 4);
//order:设置此缓冲区的字节顺序。ByteOrder.nativeOrder():返回当前平台字节顺序。
imgData.order(ByteOrder.nativeOrder());
// get image pixel
int[] pixels = new int[ddims[2] * ddims[3]];
//从当前存在的位图,按一定的比例创建一个新的位图。
Bitmap bm = Bitmap.createScaledBitmap(bitmap, ddims[2], ddims[3], false);
bm.getPixels(pixels, 0, ddims[2], 0, 0, ddims[2], ddims[3]);
int pixel = 0;
for (int i = 0; i < ddims[2]; ++i) {
for (int j = 0; j < ddims[3]; ++j) {
final int val = pixels[pixel++];
imgData.putFloat((((val & 0xFF) - 0) / 255.0f));
}
}
if (bm.isRecycled()) {
bm.recycle();
}
return imgData;
}
private int get_max_result(float[] result) {
float probability = result[0];
int r = 0;
for (int i = 0; i < result.length; i++) {
if (probability < result[i]) {
probability = result[i];
r = i;
}
}
return r;
}
private void checkPermission(){
if(Build.VERSION.SDK_INT>=Build.VERSION_CODES.M){
String[] permissions=new String[]{
Manifest.permission.CAMERA,
Manifest.permission.READ_EXTERNAL_STORAGE,
Manifest.permission.WRITE_EXTERNAL_STORAGE
};
for (String permission:permissions){
if (ContextCompat.checkSelfPermission(this,permission)!= PackageManager.PERMISSION_GRANTED){
ActivityCompat.requestPermissions(this,permissions,1);
}
}
}
}
@Override
public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
super.onRequestPermissionsResult(requestCode, permissions, grantResults);
switch (requestCode) {
case 1:
if (grantResults.length > 0) {
for (int i = 0; i < grantResults.length; i++) {
int grantResult = grantResults[i];
if (grantResult == PackageManager.PERMISSION_DENIED) {
String s = permissions[i];
Toast.makeText(this, s + " permission was denied", Toast.LENGTH_SHORT).show();
}
}
}
break;
}
}
public Bitmap getScaleBitmap(String testImage) {
//BitmapFactory.Options类代表对Bitmap对象的属性设置
BitmapFactory.Options opt = new BitmapFactory.Options();
//是否只获取信息,不加载Bitmap
opt.inJustDecodeBounds = true;
//Bitmap的工厂类BitmapFactory提供了四类静态方法用于加载Bitmap对象:decodeFile、decodeResource、decodeStream、decodeByteArray。
//分别代表从本地图片文件、项目资源文件、流对象(可以是网络输入流对象或本地文件输入流对象)、字节序列中加载一个Bitmap对象。
AssetManager assetManager=this.getAssets();
Bitmap bitmap=null;
try {
InputStream in=assetManager.open(testImage);
bitmap=BitmapFactory.decodeStream(in);
} catch (IOException e) {
e.printStackTrace();
}
return bitmap;
}
}
运行结果