手写数字识别:生成 pb 模型,并修改 tensorflow_android_demo,使其在 Android 上识别手写数字

1. 修改 tensorflow_jni.cc。MNIST 数据集使用灰度图像,并用二维矩阵 [batch, height × width] 存储数据(每幅 28 × 28 的图像被展平成 784 个灰度值);而从摄像头读取的是 RGB 图像,用四维矩阵 [batch, height, width, channel] 存储。下面的代码完成两者之间的转换:先把每个像素的 RGB 值转换为灰度值,再按行展平写入输入张量。

// Create input tensor

tensorflow::Tensor input_tensor(

tensorflow::DT_FLOAT,

tensorflow::TensorShape({

1, g_tensorflow_input_size * g_tensorflow_input_size}));

//i guess 2 represent the dimension of input_tensor.TensorShape

auto input_tensor_mapped = input_tensor.tensor();

LOG(INFO) << "Tensorflow: Copying Data.";

for (int i = 0; i < g_tensorflow_input_size; ++i) {

const RGBA* src = bitmap_src + i * g_tensorflow_input_size;

for (int j = 0; j < g_tensorflow_input_size; ++j) {

// Copy 3 values

float red =

static_cast(src->red) - g_image_mean;

float green =

static_cast(src->green) - g_image_mean;

float blue =

static_cast(src->blue) - g_image_mean;

++src;

// my added code

float gray = red * 0.299 + green * 0.587 + blue * 0.114;

input_tensor_mapped(0, i * g_tensorflow_input_size + j) = gray;

}

}

2. 记得修改 TensorflowImageListener.java,使其与 MNIST 的图片尺寸(28 × 28 = 784)和类别数(10)保持一致:

/** Side length, in pixels, of the square MNIST input image (28 x 28 = 784 values). */
private static final int INPUT_SIZE = 28;

/** Number of output classes produced by the model: the digits 0-9. */
private static final int NUM_CLASSES = 10;

3. 重新制作了 pb 模型,终于不再报错了(泪)。生成 pb 模型的 Python 脚本如下:

from __future__ import absolute_import

from __future__ import division

from __future__ import print_function

# Import data

import sys

sys.path.append('/home/houn/tensorflow/tensorflow-r0.9/tensorflow/python/tools')

import os

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

#from tensorflow.python.tools import freeze_graph

import freeze_graph

# Train a softmax-regression MNIST classifier, checkpoint it, and freeze the
# graph plus weights into a single binary GraphDef ("output_graph.pb") whose
# "input" / "output" nodes are what the Android demo feeds and fetches.
flags = tf.app.flags

FLAGS = flags.FLAGS

flags.DEFINE_string('data_dir', '/tmp/data/', 'Directory for storing data')

mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

# Single working directory for the checkpoint and both graph files.
pb_make_dir = '/home/houn/tensorflow/pbMake'

checkpoint_prefix = os.path.join(pb_make_dir, "saved_checkpoint")

checkpoint_state_name = "checkpoint_state"

input_graph_name = "input_graph.pb"

output_graph_name = "output_graph.pb"

first_graph = tf.Graph()

with first_graph.as_default():
    with tf.Session() as sess:
        print("# build graph and run")

        # "input" is the node name the Android code feeds: flattened 28x28 images.
        x = tf.placeholder(tf.float32, shape=[None, 784], name="input")
        y_ = tf.placeholder(tf.float32, shape=[None, 10], name="y_")

        W = tf.Variable(tf.zeros([784, 10]), name="W")
        b = tf.Variable(tf.zeros([10]), name="b")

        logits = tf.matmul(x, W) + b
        # "output" is the node name the Android code fetches (class probabilities).
        y = tf.nn.softmax(logits, name="output")

        # Compute the loss from the raw logits. The manual form
        # -reduce_sum(y_ * log(softmax)) yields NaN as soon as a probability
        # underflows to 0; the fused op is numerically stable.
        cross_entropy = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_))

        train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

        # Initialize only after the whole graph is built, so any variables the
        # optimizer creates are covered as well.
        sess.run(tf.initialize_all_variables())

        for _ in range(1000):
            batch = mnist.train.next_batch(50)
            train_step.run(feed_dict={x: batch[0], y_: batch[1]})

        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        print(accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

        # Save the trained variables (checkpoint "<prefix>-0") and the graph
        # structure as a binary GraphDef; freeze_graph merges them below.
        saver = tf.train.Saver()
        saver.save(sess, checkpoint_prefix, global_step=0,
                   latest_filename=checkpoint_state_name)
        tf.train.write_graph(sess.graph.as_graph_def(), pb_make_dir,
                             input_graph_name, as_text=False)

# We saved the graph to disk above; now convert its variables into constants.
input_graph_path = os.path.join(pb_make_dir, input_graph_name)
input_saver_def_path = ""
input_binary = True  # must agree with as_text=False in write_graph above
input_checkpoint_path = checkpoint_prefix + "-0"  # suffix matches global_step=0
output_node_names = "output"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_graph_path = os.path.join(pb_make_dir, output_graph_name)
clear_devices = False

freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                          input_binary, input_checkpoint_path,
                          output_node_names, restore_op_name,
                          filename_tensor_name, output_graph_path,
                          clear_devices, "")

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值