手写字识别生成pb利用android,Android上运行手写数字识别模型

66b52468c121889b900d4956032f1009.png

8种机械键盘轴体对比

本人程序员,要买一个写代码的键盘,请问红轴和茶轴怎么选?

Github源码请移步本文底部。

模型导出pb文件

首先我们需要在我们的python代码中保存训练好的模型,save_path参数就传递**.pb,这里导出文件留给接下来使用1

2

3

4def (session, save_path):

out_graph_def = tf.graph_util.convert_variables_to_constants(session, session.graph_def, ["output"])

with tf.gfile.FastGFile(save_path, 'wb') as file:

file.write(out_graph_def.SerializeToString())

Android中通过JNI调用

Tensorflow与Android整合

封装输出数据解析逻辑

在手写数字识别模型中的输出是一个size为10的列表,列表元素的索引值对应输出的结果,列表元素对应输出的概率,例如输出是[0.2, 0.7, 0.01……],即表示有0.2的概率是0,0.7的概率是1,0.01的概率是2……

因此我们需要在输出中对数据按照概率进行降序排列,以便让结果一目了然。1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53* @author zijiao

* @version 17/8/2

*/

public class MnistData{

private final List items = new ArrayList<>(10);

public MnistData(float[] data){

for (int i = 0; i < data.length; i++) {

items.add(new Item(data[i], i));

}

Collections.sort(items);

}

public String top(int topSize){

StringBuilder builder = new StringBuilder();

for (int i = 0; i < topSize; i++) {

Item item = items.get(i);

builder.append(item.index)

.append(": ")

.append(String.format("%.1f%%", item.value * 100))

.append("n");

}

return builder.toString();

}

public String output(){

return String.valueOf(items.get(0).index);

}

public String toString(){

return output();

}

@SuppressWarnings("NullableProblems")

private static class Item implements Comparable{

final float value;

final float index;

private Item(float value, float index){

this.value = value;

this.index = index;

}

public int compareTo(Item o){

return value < o.value ? 1 : -1;

}

}

}

这时我们就能通过MnistData类的top方法得到概率最大的几个结果分别是什么。

构建数字分类器

这里通过TensorFlowInferenceInterface来调用模型,注释写得很清楚,值得注意的一点是,input和output的名称要和模型中的变量名称保持一致。1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32* @author zijiao

* @version 17/8/2

*/

public class MnistClassifier{

private final TensorFlowInferenceInterface inference;

public MnistClassifier(AssetManager assetManager){

inference = new TensorFlowInferenceInterface();

// 加载模型图

inference.initializeTensorFlow(assetManager, TF.MODEL);

// 模型使用阶段, 不需要进行dropout处理, 所以keep_prob直接为1.0

inference.fillNodeFloat(TF.KEEP_PROB_NAME, new int[]{1}, new float[]{1.0f});

}

public MnistData inference(float[] input){

if (input == null || input.length != 28 * 28) {

throw new RuntimeException("Input data is error.");

}

// 填入Input数据

inference.fillNodeFloat(TF.INPUT_NAME, TF.INPUT_TYPE, input);

// 运行结果, 类似Python中的sess.run([outputs])

inference.runInference(new String[]{TF.OUTPUT_NAME});

float[] output = new float[10];

// 取出结果集中我们需要的

inference.readNodeFloat(TF.OUTPUT_NAME, output);

// 将输出结果交给MnistData处理

return new MnistData(output);

}

}

添加画板

模型处理的逻辑已经写完了,接下来就是如何得到输入源了。由于是手写数字识别,所以接下来就要写画板类。这里只贴出关键代码部分(完整代码可以看本文底部的Github地址)。

手指滑动屏幕时画出手指滑动的轨迹1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21protected void onDraw(Canvas canvas){

super.onDraw(canvas);

canvas.drawPath(path, paint);

}

public boolean onTouchEvent(MotionEvent event){

float x = event.getX();

float y = event.getY();

switch (event.getAction()) {

case MotionEvent.ACTION_DOWN:

path.moveTo(x, y);

break;

case MotionEvent.ACTION_MOVE:

path.lineTo(x, y);

break;

}

invalidate();

return true;

}

向外部提供读取画布数据的方法1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32public float[] fetchData(int width, int height) {

float[] data = new float[height * width];

try {

setDrawingCacheEnabled(true);

setDrawingCacheQuality(View.DRAWING_CACHE_QUALITY_LOW);

Bitmap cache = getDrawingCache();

fillInputData(cache, data, width, height);

} finally {

setDrawingCacheEnabled(false);

}

return data;

}

private void fillInputData(Bitmap bm, float[] data, int newWidth, int newHeight){

// 获得图片的宽高

int width = bm.getWidth();

int height = bm.getHeight();

// 计算缩放比例

float scaleWidth = ((float) newWidth) / width;

float scaleHeight = ((float) newHeight) / height;

// 取得想要缩放的matrix参数

Matrix matrix = new Matrix();

matrix.postScale(scaleWidth, scaleHeight);

// 得到新的图片

Bitmap newbm = Bitmap.createBitmap(bm, 0, 0, width, height, matrix, true);

for (int y = 0; y < newHeight; y++) {

for (int x = 0; x < newWidth; x++) {

int pixel = newbm.getPixel(x, y);

data[newWidth * y + x] = pixel == 0xffffffff ? 0 : 1;

}

}

}

运行测试

布局代码就直接省略了,我们只需要在点击识别的时候,调用下面这段的识别逻辑即可。1

2

3

4

5

6

7

8

9// 识别

public void onInference(View view){

if (canvasView.isEmpty()) {

resultPanel.setText("画板为空");

return;

}

MnistData result = classifier.inference(canvasView.fetchData(28, 28));

resultPanel.setText(result.top(3));

}

最后附上运行效果图

91eca7725bcff81615f68b1a6ed593df.png

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值