前段时间训练了mnist手写数字识别的模型,学习后将其移植到Android端
我是参考的大佬https://puke3615.github.io/2017/08/02/Run-Mnist-On-Android/,https://github.com/wangtianrui/TFonAndroid的源码,有需要的的朋友可以去下载,这里是对他写的代码的分析和我自己的理解
注解ButterKnife学习:https://www.jianshu.com/p/952c6f5e8157
implementation 'com.jakewharton:butterknife:8.8.1'
手机上效果为:
移植到Android时要添加依赖文件:libandroid_tensorflow_inference_java.jar,和编译后的TensoFlow的so库,libtensorflow_inference.so,将其添加在lib文件夹中:
接下来将训练好的pb模型放入assets文件夹中
在build.gradle文件中添加:这个可以支持在手机中调试
testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
这里分享一下运行程序时遇到的坑:
出现了问题“Android-Device supports x86,but APK only supports armeabi-v7a,armeabi,x86_64”,使用模拟器不能运行,因为之前添加了支持tensorflow的so库和jar包,后来我在build文件中添加
multiDexEnabled true
ndk {
abiFilters "armeabi-v7a"
}
也还是没有用,后来看到大佬的代码仿佛才明白了一些东西。。。
在build中增加:
sourceSets {
main {
jniLibs.srcDirs = ['libs']
}
}
下面是实现过程:
定义mnist分类器:
private final String MODEL_PATH = "file:///android_asset/mnist.pb";//加载模型
public static final String INPUT_NAME = "input";//对应训练模型占位符x-input
public static final String KEEP_PROB_NAME = "keep_prob";
public static final String OUTPUT_NAME = "output";//训练模型时的占位符y_
//注意训练模型时一定要对值和标签设置name值,导入Android后要喂数据
//tensorflow依赖文件的类
private TensorFlowInferenceInterface inference;
//图片像素28*28
private final int width = 28;
private final int heifht = 28;
private float[] inputs = new float[width * heifht];
private int[] INPUT_SHAPE = new int[]{1, width * heifht};
//AssetManager :提供低级别的访问应用资源的API
//不同模型框架,训练模型输入的占位符不同,一定要一一对应;
//训练的数据集一定要对齐,resize与采用模型框架图像的大小一致,Android端调用接口,输入参数一定要和训练图像一致,否则会出现分类错误。
public MnistClassifier(AssetManager assetManager) {
this.inference = new TensorFlowInferenceInterface(assetManager, MODEL_PATH);//传入模型的路径
//模型使用阶段, 不需要进行dropout处理, 所以keep_prob直接为1.0
//dropout层:keep_prob训练时为0.5,测试时为1
inference.feed(KEEP_PROB_NAME, new float[]{1.0f}, 1);
}
public float[] getResult(float[] inputs) {
try {
this.inputs = inputs;
} catch (Exception e) {
e.printStackTrace();
}
//输出结果是十个数字的概率
float[] output = new float[10];
//填入Input数据
inference.feed(INPUT_NAME, inputs, 1, width * heifht);
//运行结果, 类似Python中的sess.run([outputs])
inference.run(new String[]{OUTPUT_NAME}, false);
inference.fetch(OUTPUT_NAME, output);
return output;
}
定义画板:
public class PrinterView extends View {
//画笔
private Paint paint;
//用来存储“路径”
private Path path;
//屏幕宽
private int width;
public PrinterView(Context context) {
super(context);
}
public PrinterView(Context context, @Nullable AttributeSet attrs) {
super(context, attrs);
setBackgroundColor(Color.WHITE);
paint = new Paint();
paint.setColor(Color.RED);
paint.setStrokeWidth(TypedValue.applyDimension(TypedValue.COMPLEX_UNIT_DIP, 20, getResources().getDisplayMetrics()));
paint.setStyle(Paint.Style.STROKE);
path = new Path();
int screenWidth = getResources().getDisplayMetrics().widthPixels;
width = MeasureSpec.makeMeasureSpec(screenWidth, MeasureSpec.EXACTLY);
}
/**
1.精确模式(MeasureSpec.EXACTLY)
在这种模式下,尺寸的值是多少,那么这个组件的长或宽就是多少。
2.最大模式(MeasureSpec.AT_MOST)
这个也就是父组件,能够给出的最大的空间,当前组件的长或宽最大只能为这么大,当然也可以比这个小。
3.未指定模式(MeasureSpec.UNSPECIFIED)
这个就是说,当前组件,可以随便用空间,不受限制。
*/
//画出手指滑动的轨迹
@Override
public boolean onTouchEvent(MotionEvent event) {
float x = event.getX();
float y = event.getY();
switch (event.getAction()) {
case MotionEvent.ACTION_DOWN:
//按下
path.moveTo(x, y);
break;
case MotionEvent.ACTION_MOVE:
path.lineTo(x, y);
break;
}
//刷新view
invalidate();
return true;
}
@Override
protected void onDraw(Canvas canvas) {
super.onDraw(canvas);
canvas.drawPath(path, paint);
}
@Override
protected void onMeasure(int widthMeasureSpec, int heightMeasureSpec) {
//定制画板的宽和高
super.onMeasure(width, width);
}
public void clean() {
path.reset();
invalidate();
}
public boolean isEmpty() {
return path.isEmpty();
}
//向外部提供读取画布数据的方法
/**
View组件显示的内容可以通过cache机制保存为bitmap, 使用到的api有
void setDrawingCacheEnabled(boolean flag),
Bitmap getDrawingCache(boolean autoScale),
void buildDrawingCache(boolean autoScale),
void destroyDrawingCache()
我们要获取它的cache先要通过setDrawingCacheEnable方法把cache开启,然后再调用getDrawingCache方法就可 以获得view的cache图片了。buildDrawingCache方法可以不用调用,因为调用getDrawingCache方法时,若果 cache没有建立,系统会自动调用buildDrawingCache方法生成cache。若果要更新cache, 必须要调用destoryDrawingCache方法把旧的cache销毁,才能建立新的。
当调用setDrawingCacheEnabled方法设置为false, 系统也会自动把原来的cache销毁。
ViewGroup在绘制子view时,而外提供了两个方法
void setChildrenDrawingCacheEnabled(boolean enabled)
setChildrenDrawnWithCacheEnabled(boolean enabled)
setChildrenDrawingCacheEnabled方法可以使viewgroup里所有的子view开启cache, setChildrenDrawnWithCacheEnabled使在绘制子view时,若该子view开启了cache, 则使用它的cache进行绘制,从而节省绘制时间。
获取cache通常会占用一定的内存,所以通常不需要的时候有必要对其进行清理,通过destroyDrawingCache或setDrawingCacheEnabled(false)实现。
*/
public float[] getData(int width, int height) {
float[] data = new float[height * width];
try {
//先让cache可以被读取(将View转化为图片都会使用cache)
setDrawingCacheEnabled(true);
setDrawingCacheQuality(View.DRAWING_CACHE_QUALITY_LOW);
Bitmap cache = getDrawingCache();
dealData(cache, data, width, height);
} finally {
setDrawingCacheEnabled(false);
}
return data;
}
private void dealData(Bitmap bm, float[] data, int newWidth, int newHeight) {
//获得bitmap的宽和高
int width = bm.getWidth();
int height = bm.getHeight();
//计算缩放比例
float scaleWidth = ((float) newWidth) / width;
float scaleHeight = ((float) newHeight) / height;
//取得想要缩放的matrix参数
Matrix matrix = new Matrix();
matrix.postScale(scaleWidth, scaleHeight);
//获得目标大小的图
Bitmap newBm = Bitmap.createBitmap(bm, 0, 0, width, height, matrix, true);
for (int y = 0; y < newHeight; y++) {
for (int x = 0; x < newWidth; x++) {
//获得每个点的像素值
int pixel = newBm.getPixel(x, y);
data[newWidth * y + x] = pixel == 0xffffffff ? 0 : 1;
}
}
}
}
识别逻辑:
@OnClick({R.id.printer_view, R.id.result_text_view, R.id.clean_button, R.id.detect_button})
public void onViewClicked(View view) {
switch (view.getId()) {
case R.id.clean_button:
printerView.clean();
resultTextView.setText(null);
break;
case R.id.detect_button:
if (printerView.isEmpty()) {
resultTextView.setText("画板为空");
break;
}
MnistClassifier mnistClassifier = new MnistClassifier(getAssets());
float[] result = mnistClassifier.getResult(printerView.getData(28, 28));
List<MnistItem> items = new ArrayList<>(10);
for (int i = 0; i < result.length; i++) {
items.add(new MnistItem(result[i], i));
}
Collections.sort(items);//选择概率最大的对应的值
StringBuilder builder = new StringBuilder();
for (int i = 0; i < 1 ; i++) {
MnistItem item = items.get(i);
builder.append((int)item.getIndex());
}
resultTextView.setText(builder.toString());
break;
}