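Overview of calling a TF model from a phone:
1. Save the trained TF model.
2. Import the TF model into the Android project, together with the jar package and .so file that Android needs in order to call it (they parse the model and run its computations).
3. Define variables, store the data, and call the model through the interface provided by the jar package.
Porting process
We use an image recognition model that we trained ourselves on the MNIST dataset as the running example.
I. In the Python code that defines the TF model, give the input-layer tensor and the output-layer tensor a name each (via the 'name' parameter)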
X = tf.placeholder(tf.float32, shape=[…], name='input')  # network input
Y = tf.nn.softmax(tf.matmul(f, out_weights) + out_biases, name='output')  # network output
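The names are arbitrary; pick something easy to remember, because they are used repeatedly later on. I used input and output.
II. Save the model trained with TensorFlow as a .pb file
At the point in the code where training has finished, add the following two statements to save the model as a .pb file: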
output_graph_def = tf.graph_util.convert_variables_to_constants(session, session.graph_def, output_node_names=['output'])
# output_node_names specifies the names of the output nodes
with tf.gfile.FastGFile('model/mnist.pb', mode='wb') as f:
    f.write(output_graph_def.SerializeToString())
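The first argument gives the output path, file name and format. I put the file in a model folder next to the code and named it mnist.pb.
The second argument, mode, sets the file mode: in 'wb', w means write and b means the data is written in binary.
Without 'b' the file is opened in text mode. TF currently cannot parse a .pb file written that way, so calling the model fails with an error.
Notes:
1) Do not use tf.train.write_graph() to save the model: it saves only the structure of the model, not the trained parameter values.
2) Do not use tf.train.Saver() to save the model: its checkpoint holds the parameter values of the network, not the structure of the model.
What we need is a file that holds both the structure of the model and the value of every parameter. Neither of the two provides that on its own.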
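V. Add the resources to the project
1) Put the .pb file generated in step II into the project
Open the Project view and go to app/src/main/assets.
If the assets directory does not exist, right-click main -> New -> Folder -> Assets Folder.
2) Add the jar package generated in step III
Open the Project view and copy the jar package into app -> libs.
Select the jar file, right-click it and choose Add As Library.
3) Add the .so file generated in step III
Open the Project view and copy the .so file into app/src/main/jniLibs (create the jniLibs folder if it does not exist).
If my explanation is unclear, search Google for "how to add jar and .so references in Android Studio".
VI. Create the interface and call the model
1) Import the jar package and the .so file
In the .java file that calls the model, import the class from the jar package: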
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
At the top of that Java class definition, load the .so file in a static initializer:
static {
    System.loadLibrary("tensorflow_inference");
}
2) Define the variables and objects
private static final String MODEL_FILE = "file:///android_asset/mnist.pb"; // path to the model file
private static final String INPUT_NODE = "input"; // name of the model's input node
private static final String OUTPUT_NODE = "output"; // name of the model's output node
private static final int NUM_CLASSES = 10; // number of classes in the dataset; 10 for MNIST
private static final int HEIGHT = 24; // input image height in pixels
private static final int WIDTH = 24; // input image width in pixels
private static final int CHANNEL = 3; // number of input image channels: RGB
private float[] inputs = new float[HEIGHT * WIDTH * CHANNEL]; // holds the model input data
private float[] outputs = new float[NUM_CLASSES]; // holds the model output data
3) Initialize the TensorFlow interface
private TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(); // declare the interface
inferenceInterface.initializeTensorFlow(getAssets(), MODEL_FILE); // initialize the interface
Once the two steps above are done, the model can be called repeatedly.
Before each call, store the input data in order in the inputs array, then execute the three statements below.
4) Call the TF model
inferenceInterface.fillNodeFloat(INPUT_NODE, new int[]{1, HEIGHT, WIDTH, CHANNEL}, inputs); // feed the input data
inferenceInterface.runInference(new String[]{OUTPUT_NODE}); // run inference
inferenceInterface.readNodeFloat(OUTPUT_NODE, outputs); // fetch the output data
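Two small details around these calls are easy to get wrong: where a given pixel lands in the flat inputs array, and how to turn outputs into a predicted class. The sketch below shows both in plain Java. It assumes the model expects the row-major [1, HEIGHT, WIDTH, CHANNEL] layout passed to fillNodeFloat; ModelIo, nhwcIndex and argMax are hypothetical helper names, not part of the TensorFlow jar.

```java
/** Hypothetical helpers (not part of the TensorFlow jar). */
final class ModelIo {
    private ModelIo() {}

    /**
     * Offset of element (y, x, c) in a flat array that holds one image
     * in row-major [1, height, width, channels] order (assumed layout).
     */
    static int nhwcIndex(int y, int x, int c, int width, int channels) {
        return (y * width + x) * channels + c;
    }

    /** Index of the largest score, or -1 if the array is empty. */
    static int argMax(float[] scores) {
        if (scores.length == 0) {
            return -1;
        }
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }
}
```

With these helpers, a value for pixel (y, x) and channel c would be written to inputs[ModelIo.nhwcIndex(y, x, c, WIDTH, CHANNEL)], and ModelIo.argMax(outputs) would give the predicted digit after readNodeFloat returns.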
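Implementing the demo
Analyzing the source code
In Android, a method marked native is an interface implemented in a dynamic link library. Reading the Java code of this image classification demo, we find the following three TensorFlow interfaces: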
// load the tensorflow
public native int initializeTensorFlow(
AssetManager assetManager,
String model,
String labels,
int numClasses,
int inputSize,
int imageMean,
float imageStd,
String inputName,
String outputName);
// classify the image by input the bitmap
private native String classifyImageBmp(Bitmap bitmap);
// classify the image by input the rgb
private native String classifyImageRgb(int[] output, int width, int height);
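So we only need to learn these three functions to port TensorFlow into our own project. Note that the dynamic link library must also be loaded with a System.loadLibrary call before any of them is used.
These functions are used in TensorFlowImageListener. To use them, we first create a TensorFlowClassifier object: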
private final TensorFlowClassifier tensorflow = new TensorFlowClassifier();
Then we load the TensorFlow model, which takes the parameters below. Note that in TensorFlowImageClassifier I have already removed the parameters I do not need for now:
private static final int NUM_CLASSES = 1001;
private static final int INPUT_SIZE = 224;
private static final int IMAGE_MEAN = 117;
private static final float IMAGE_STD = 1;
private static final String INPUT_NAME = "input:0";
private static final String OUTPUT_NAME = "output:0";
private static final String MODEL_FILE = "file:///android_asset/tensorflow_inception_graph.pb";
private static final String LABEL_FILE = "file:///android_asset/imagenet_comp_graph_label_strings.txt";
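Do not forget that the first argument when loading the model is assetManager. It is what locates the files holding the training results (the .pb and .txt files); if it is null, an exception is thrown. The remaining parameters are mostly self-explanatory from their names, such as the input name and the 224*224 image size.
Once the model is initialized, we can classify images: call private native String classifyImageBmp(Bitmap bitmap); and pass in the Bitmap directly, after resizing the image to INPUT_SIZE.
Problem 2:
The app crashes while loading the model
Check that the assets directory location is correct, that is, the first argument. If it is wrong, the model cannot be loaded.
Using TensorFlow in a project
I emptied the main Activity of the demo above and wrote a new one. The app opens a photo from the phone's gallery, shows it on screen, and displays the most likely name of the object at the top (using Google's trained model). The project depends on the TensorFlowClassifier class from the demo above.
Main Activity code: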
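IMAGE_MEAN and IMAGE_STD describe how each color channel is normalized before it reaches the network: (value - mean) / std. In the demo this happens inside the native code, but the arithmetic is easy to show in plain Java. The sketch below assumes pixels packed as ARGB ints, the layout that Bitmap.getPixels returns; PixelNormalizer and toFloats are hypothetical names used only for illustration.

```java
/** Hypothetical helper: mirrors the per-channel normalization done natively. */
final class PixelNormalizer {
    private PixelNormalizer() {}

    /**
     * Converts packed ARGB pixels to interleaved R, G, B floats,
     * each normalized as (value - mean) / std. Alpha is ignored.
     */
    static float[] toFloats(int[] argbPixels, float mean, float std) {
        float[] out = new float[argbPixels.length * 3];
        for (int i = 0; i < argbPixels.length; i++) {
            int p = argbPixels[i];
            out[3 * i] = (((p >> 16) & 0xFF) - mean) / std;    // red
            out[3 * i + 1] = (((p >> 8) & 0xFF) - mean) / std; // green
            out[3 * i + 2] = ((p & 0xFF) - mean) / std;        // blue
        }
        return out;
    }
}
```

With IMAGE_MEAN = 117 and IMAGE_STD = 1, a channel value of 117 maps to 0 and values in 0..255 map to -117..138.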
package org.tensorflow.demo;
import java.io.FileNotFoundException;
import java.util.List;
import android.app.Activity;
import android.content.ContentResolver;
import android.content.Intent;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.Matrix;
import android.net.Uri;
import android.os.Bundle;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;
public class CameraActivity extends Activity {
    private static final String MODEL_FILE = "file:///android_asset/tensorflow_inception_graph.pb";
    private static final String LABEL_FILE = "file:///android_asset/imagenet_comp_graph_label_strings.txt";
    private static final int NUM_CLASSES = 1001;
    private static final int INPUT_SIZE = 224;
    private static final int IMAGE_MEAN = 117;
    private static final float IMAGE_STD = 1;
    private static final String INPUT_NAME = "input:0";
    private static final String OUTPUT_NAME = "output:0";

    private final TensorFlowClassifier tensorflow = new TensorFlowClassifier();
    private TextView mResultText;

    /** Called when the activity is first created. */
    @Override
    public void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_camera);

        // Load the TensorFlow model and labels from the assets directory.
        final AssetManager assetManager = getAssets();
        tensorflow.initializeTensorFlow(
                assetManager, MODEL_FILE, LABEL_FILE, NUM_CLASSES, INPUT_SIZE, IMAGE_MEAN, IMAGE_STD,
                INPUT_NAME, OUTPUT_NAME);

        Button button = (Button) findViewById(R.id.b01);
        button.setText("Select image");
        button.setOnClickListener(new Button.OnClickListener() {
            @Override
            public void onClick(View v) {
                Intent intent = new Intent();
                // Open the picture chooser; restrict the type to images.
                intent.setType("image/*");
                // Use the Intent.ACTION_GET_CONTENT action.
                intent.setAction(Intent.ACTION_GET_CONTENT);
                // Return to this screen once a photo has been picked.
                startActivityForResult(intent, 1);
            }
        });
    }

    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        if (resultCode == RESULT_OK) {
            Uri uri = data.getData();
            Log.e("uri", uri.toString());
            ContentResolver cr = this.getContentResolver();
            try {
                Bitmap bitmap = BitmapFactory.decodeStream(cr.openInputStream(uri));
                dealPics(bitmap);
            } catch (FileNotFoundException e) {
                Log.e("Exception", e.getMessage(), e);
            }
        }
        super.onActivityResult(requestCode, resultCode, data);
    }

    private void dealPics(Bitmap bitmap) {
        ImageView imageView = (ImageView) findViewById(R.id.iv01);
        int width = bitmap.getWidth();
        int height = bitmap.getHeight();
        System.out.println(width + "&&" + height);

        // Scale the picture to INPUT_SIZE x INPUT_SIZE, the size the model expects.
        float scaleWidth = ((float) INPUT_SIZE) / width;
        float scaleHeight = ((float) INPUT_SIZE) / height;
        Matrix matrix = new Matrix();
        matrix.postScale(scaleWidth, scaleHeight);
        Bitmap newbm = Bitmap.createBitmap(bitmap, 0, 0, width, height, matrix, true);

        // Show the scaled Bitmap in the ImageView.
        imageView.setImageBitmap(newbm);

        final List<Classifier.Recognition> results = tensorflow.recognizeImage(newbm);
        for (final Classifier.Recognition result : results) {
            System.out.println("Result: " + result.getTitle());
        }
        mResultText = (TextView) findViewById(R.id.t01);
        if (!results.isEmpty()) { // guard: the classifier may return no result
            mResultText.setText("Detected = " + results.get(0).getTitle());
        }
        System.out.println(newbm.getWidth() + "&&" + newbm.getHeight());
    }
}
JNI code
#include "tensorflow/examples/android/jni/tensorflow_jni.h"
#include <android/asset_manager.h>
#include <android/asset_manager_jni.h>
#include <android/bitmap.h>
#include <jni.h>
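#include <pthread.h>
#include <sys/stat.h>
#include <sys/time.h>  // gettimeofday
#include <unistd.h>
#include <algorithm>  // std::reverse
#include <memory>     // std::unique_ptr
#include <queue>
#include <sstream>
#include <string>
#include <vector>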
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/stat_summarizer.h"
#include "tensorflow/examples/android/jni/jni_utils.h"
using namespace tensorflow;
// Global variables that holds the TensorFlow classifier.
static std::unique_ptr<tensorflow::Session> session;
static std::vector<std::string> g_label_strings;
static bool g_compute_graph_initialized = false;
// static mutex g_compute_graph_mutex(base::LINKER_INITIALIZED);
static int g_tensorflow_input_size; // The image size for the model input.
static int g_image_mean; // The image mean.
static float g_image_std; // The scale value for the input image.
static std::unique_ptr<std::string> g_input_name;
static std::unique_ptr<std::string> g_output_name;
static std::unique_ptr<StatSummarizer> g_stats;
// For basic benchmarking.
static int g_num_runs = 0;
static int64 g_timing_total_us = 0;
static Stat<int64> g_frequency_start;
static Stat<int64> g_frequency_end;
#ifdef LOG_DETAILED_STATS
static const bool kLogDetailedStats = true;
#else
static const bool kLogDetailedStats = false;
#endif
// Improve benchmarking by limiting runs to predefined amount.
// 0 (default) denotes infinite runs.
#ifndef MAX_NUM_RUNS
#define MAX_NUM_RUNS 0
#endif
#ifdef SAVE_STEP_STATS
static const bool kSaveStepStats = true;
#else
static const bool kSaveStepStats = false;
#endif
inline static int64 CurrentThreadTimeUs() {
struct timeval tv;
gettimeofday(&tv, NULL);
// Widen before multiplying so the product cannot overflow a 32-bit tv_sec.
return static_cast<int64>(tv.tv_sec) * 1000000 + tv.tv_usec;
}
JNIEXPORT jint JNICALL TENSORFLOW_METHOD(initializeTensorFlow)(
JNIEnv* env, jobject thiz, jobject java_asset_manager, jstring model,
jstring labels, jint num_classes, jint model_input_size, jint image_mean,
jfloat image_std, jstring input_name, jstring output_name) {
g_num_runs = 0;
g_timing_total_us = 0;
g_frequency_start.Reset();
g_frequency_end.Reset();
// MutexLock input_lock(&g_compute_graph_mutex);
if (g_compute_graph_initialized) {
LOG(INFO) << "Compute graph already loaded. skipping.";
return 0;
}
const int64 start_time = CurrentThreadTimeUs();
const char* const model_cstr = env->GetStringUTFChars(model, NULL);
const char* const labels_cstr = env->GetStringUTFChars(labels, NULL);
g_tensorflow_input_size = model_input_size;
g_image_mean = image_mean;
g_image_std = image_std;
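// Copy each name into a std::string, then release the JNI buffer to avoid a leak.
const char* const input_name_cstr = env->GetStringUTFChars(input_name, NULL);
g_input_name.reset(new std::string(input_name_cstr));
env->ReleaseStringUTFChars(input_name, input_name_cstr);
const char* const output_name_cstr = env->GetStringUTFChars(output_name, NULL);
g_output_name.reset(new std::string(output_name_cstr));
env->ReleaseStringUTFChars(output_name, output_name_cstr);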
LOG(INFO) << "Loading TensorFlow.";
LOG(INFO) << "Making new SessionOptions.";
tensorflow::SessionOptions options;
tensorflow::ConfigProto& config = options.config;
LOG(INFO) << "Got config, " << config.device_count_size() << " devices";
session.reset(tensorflow::NewSession(options));
LOG(INFO) << "Session created.";
tensorflow::GraphDef tensorflow_graph;
LOG(INFO) << "Graph created.";
AAssetManager* const asset_manager =
AAssetManager_fromJava(env, java_asset_manager);
LOG(INFO) << "Acquired AssetManager.";
LOG(INFO) << "Reading file to proto: " << model_cstr;
ReadFileToProto(asset_manager, model_cstr, &tensorflow_graph);
g_stats.reset(new StatSummarizer(tensorflow_graph));
LOG(INFO) << "Creating session.";
tensorflow::Status s = session->Create(tensorflow_graph);
if (!s.ok()) {
LOG(FATAL) << "Could not create TensorFlow Graph: " << s;
}
// Clear the proto to save memory space.
tensorflow_graph.Clear();
LOG(INFO) << "TensorFlow graph loaded from: " << model_cstr;
// Read the label list
ReadFileToVector(asset_manager, labels_cstr, &g_label_strings);
LOG(INFO) << g_label_strings.size()
<< " label strings loaded from: " << labels_cstr;
g_compute_graph_initialized = true;
const int64 end_time = CurrentThreadTimeUs();
LOG(INFO) << "Initialization done in " << (end_time - start_time) / 1000
<< "ms";
return 0;
}
namespace {
typedef struct {
uint8 red;
uint8 green;
uint8 blue;
uint8 alpha;
} RGBA;
} // namespace
// Returns the top N confidence values over threshold in the provided vector,
// sorted by confidence in descending order.
static void GetTopN(
const Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>,
Eigen::Aligned>& prediction,
const int num_results, const float threshold,
std::vector<std::pair<float, int> >* top_results) {
// Will contain top N results in ascending order.
std::priority_queue<std::pair<float, int>,
std::vector<std::pair<float, int> >,
std::greater<std::pair<float, int> > >
top_result_pq;
const int count = prediction.size();
for (int i = 0; i < count; ++i) {
const float value = prediction(i);
// Only add it if it beats the threshold and has a chance at being in
// the top N.
if (value < threshold) {
continue;
}
top_result_pq.push(std::pair<float, int>(value, i));
// If at capacity, kick the smallest value out.
if (top_result_pq.size() > num_results) {
top_result_pq.pop();
}
}
// Copy to output vector and reverse into descending order.
while (!top_result_pq.empty()) {
top_results->push_back(top_result_pq.top());
top_result_pq.pop();
}
std::reverse(top_results->begin(), top_results->end());
}
static int64 GetCpuSpeed() {
string scaling_contents;
ReadFileToString(nullptr,
"/sys/devices/system/cpu/cpu0/cpufreq/scaling_cur_freq",
&scaling_contents);
std::stringstream ss(scaling_contents);
int64 result;
ss >> result;
return result;
}
static std::string ClassifyImage(const RGBA* const bitmap_src) {
// Force the app to quit if we've reached our run quota, to make
// benchmarks more reproducible.
if (MAX_NUM_RUNS > 0 && g_num_runs >= MAX_NUM_RUNS) {
LOG(INFO) << "Benchmark complete. "
<< (g_timing_total_us / g_num_runs / 1000) << "ms/run avg over "
<< g_num_runs << " runs.";
LOG(INFO) << "";
exit(0);
}
++g_num_runs;
// Create input tensor
tensorflow::Tensor input_tensor(
tensorflow::DT_FLOAT,
tensorflow::TensorShape(
{1, g_tensorflow_input_size, g_tensorflow_input_size, 3}));
auto input_tensor_mapped = input_tensor.tensor<float, 4>();
LOG(INFO) << "TensorFlow: Copying Data.";
for (int i = 0; i < g_tensorflow_input_size; ++i) {
const RGBA* src = bitmap_src + i * g_tensorflow_input_size;
for (int j = 0; j < g_tensorflow_input_size; ++j) {
// Copy 3 values
input_tensor_mapped(0, i, j, 0) =
(static_cast<float>(src->red) - g_image_mean) / g_image_std;
input_tensor_mapped(0, i, j, 1) =
(static_cast<float>(src->green) - g_image_mean) / g_image_std;
input_tensor_mapped(0, i, j, 2) =
(static_cast<float>(src->blue) - g_image_mean) / g_image_std;
++src;
}
}
std::vector<std::pair<std::string, tensorflow::Tensor> > input_tensors(
{{*g_input_name, input_tensor}});
VLOG(0) << "Start computing.";
std::vector<tensorflow::Tensor> output_tensors;
std::vector<std::string> output_names({*g_output_name});
tensorflow::Status s;
int64 start_time, end_time;
if (kLogDetailedStats || kSaveStepStats) {
RunOptions run_options;
run_options.set_trace_level(RunOptions::FULL_TRACE);
RunMetadata run_metadata;
g_frequency_start.UpdateStat(GetCpuSpeed());
start_time = CurrentThreadTimeUs();
s = session->Run(run_options, input_tensors, output_names, {},
&output_tensors, &run_metadata);
end_time = CurrentThreadTimeUs();
g_frequency_end.UpdateStat(GetCpuSpeed());
assert(run_metadata.has_step_stats());
const StepStats& stats = run_metadata.step_stats();
if (kLogDetailedStats) {
LOG(INFO) << "CPU frequency start: " << g_frequency_start;
LOG(INFO) << "CPU frequency end: " << g_frequency_end;
g_stats->ProcessStepStats(stats);
g_stats->PrintStepStats();
}
if (kSaveStepStats) {
mkdir("/sdcard/tf/", 0755);
const string filename =
strings::Printf("/sdcard/tf/stepstats%05d.pb", g_num_runs);
WriteProtoToFile(filename.c_str(), stats);
}
} else {
start_time = CurrentThreadTimeUs();
s = session->Run(input_tensors, output_names, {}, &output_tensors);
end_time = CurrentThreadTimeUs();
}
const int64 elapsed_time_inf = end_time - start_time;
g_timing_total_us += elapsed_time_inf;
VLOG(0) << "End computing. Ran in " << elapsed_time_inf / 1000 << "ms ("
<< (g_timing_total_us / g_num_runs / 1000) << "ms avg over "
<< g_num_runs << " runs)";
if (!s.ok()) {
LOG(FATAL) << "Error during inference: " << s;
}
VLOG(0) << "Reading from layer " << output_names[0];
tensorflow::Tensor* output = &output_tensors[0];
const int kNumResults = 5;
const float kThreshold = 0.1f;
std::vector<std::pair<float, int> > top_results;
GetTopN(output->flat<float>(), kNumResults, kThreshold, &top_results);
std::stringstream ss;
ss.precision(3);
for (const auto& result : top_results) {
const float confidence = result.first;
const int index = result.second;
ss << index << " " << confidence << " ";
// Write out the result as a string
if (index < g_label_strings.size()) {
// just for safety: theoretically, the output is under 1000 unless there
// is some numerical issues leading to a wrong prediction.
ss << g_label_strings[index];
} else {
ss << "Prediction: " << index;
}
ss << "\n";
}
LOG(INFO) << "Predictions: " << ss.str();
return ss.str();
}
JNIEXPORT jstring JNICALL TENSORFLOW_METHOD(classifyImageRgb)(
JNIEnv* env, jobject thiz, jintArray image, jint width, jint height) {
// Copy image into currFrame.
jboolean iCopied = JNI_FALSE;
jint* pixels = env->GetIntArrayElements(image, &iCopied);
std::string result = ClassifyImage(reinterpret_cast<const RGBA*>(pixels));
env->ReleaseIntArrayElements(image, pixels, JNI_ABORT);
return env->NewStringUTF(result.c_str());
}
JNIEXPORT jstring JNICALL TENSORFLOW_METHOD(classifyImageBmp)(JNIEnv* env,
jobject thiz,
jobject bitmap) {
// Obtains the bitmap information.
AndroidBitmapInfo info;
CHECK_EQ(AndroidBitmap_getInfo(env, bitmap, &info),
ANDROID_BITMAP_RESULT_SUCCESS);
void* pixels;
CHECK_EQ(AndroidBitmap_lockPixels(env, bitmap, &pixels),
ANDROID_BITMAP_RESULT_SUCCESS);
LOG(INFO) << "Image dimensions: " << info.width << "x" << info.height
<< " stride: " << info.stride;
// TODO(andrewharp): deal with other formats if necessary.
if (info.format != ANDROID_BITMAP_FORMAT_RGBA_8888) {
LOG(FATAL) << "Only RGBA_8888 Bitmaps are supported.";
}
std::string result = ClassifyImage(static_cast<const RGBA*>(pixels));
// Finally, unlock the pixels
CHECK_EQ(AndroidBitmap_unlockPixels(env, bitmap),
ANDROID_BITMAP_RESULT_SUCCESS);
return env->NewStringUTF(result.c_str());
}
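GetTopN in the JNI code above keeps a min-heap of at most N candidates, so the weakest one is evicted whenever the heap grows past N, and then reverses the drained heap into descending order. The same idea can be sketched in Java, which may help if the selection is ever moved to the Java side. TopN and topN are hypothetical names, and unlike the native version this sketch returns only the indices and does not break ties by index.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.PriorityQueue;

/** Hypothetical Java counterpart of the native GetTopN helper. */
final class TopN {
    private TopN() {}

    /** Returns the indices of up to n scores that are >= threshold, best first. */
    static List<Integer> topN(float[] scores, int n, float threshold) {
        // Min-heap ordered by score: the head is always the weakest candidate.
        PriorityQueue<Integer> heap =
                new PriorityQueue<>((a, b) -> Float.compare(scores[a], scores[b]));
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] < threshold) {
                continue; // below the confidence threshold
            }
            heap.add(i);
            if (heap.size() > n) {
                heap.poll(); // over capacity: drop the smallest score
            }
        }
        List<Integer> result = new ArrayList<>();
        while (!heap.isEmpty()) {
            result.add(heap.poll()); // drained in ascending order
        }
        Collections.reverse(result); // flip to descending order
        return result;
    }
}
```

The heap never holds more than n + 1 entries, so the cost stays at O(count * log n) even for the 1001-class output used in this demo.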