Reposted from http://www.zhimengzhe.com/Androidkaifa/298339.html
Deploying a TensorFlow Model on Android
Overview

With the broad adoption of deep learning and the open-sourcing of TensorFlow, on-device model applications keep appearing. This article collects the notes I took while setting one up, in the hope that they help you.

Installing the CPU-only TensorFlow on macOS

If you do not have a good GPU available, install the CPU-only build of TensorFlow. Linux and macOS support both the Python 2 and Python 3 builds; Windows supports Python 3 only.

1. Install Bazel, the build tool we will use later to produce the TensorFlow Android JAR and .so library. On macOS install it with brew install bazel, or follow the Bazel documentation to install a specific version.
2. Install the CPU-only TensorFlow with pip: pip install tensorflow.
3. Verify that the installation succeeded:
import tensorflow as tf
# Print the installed TensorFlow version
tf.__version__
Building the JAR and .so library

Clone the TensorFlow repository from GitHub, then edit the WORKSPACE file in the tensorflow directory so that the sdk and ndk entries point at your local paths. The SDK API level must be ≥ 23, and NDK r12b is recommended (newer NDK versions run into problems with the Bazel build); set build_tools_version to match what you have installed:
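As a concrete example, the two rules to uncomment and edit look roughly like this (the rule names follow the stock tensorflow/WORKSPACE template of that era; the paths and version numbers below are placeholders for your local setup):

```
android_sdk_repository(
    name = "androidsdk",
    api_level = 23,
    build_tools_version = "25.0.2",
    path = "/path/to/Android/sdk",
)

android_ndk_repository(
    name = "androidndk",
    api_level = 21,
    path = "/path/to/android-ndk-r12b",
)
```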
Then build the JAR and the .so with the commands below.
Reference:
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/android
Command to build the .so library (you can choose the target CPU architecture):
bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so \
    --crosstool_top=//external:android/crosstool \
    --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
    --cpu=armeabi-v7a
The .so library ends up at:
bazel-bin/tensorflow/contrib/android/libtensorflow_inference.so
Command to build the JAR:
bazel build //tensorflow/contrib/android:android_tensorflow_inference_java
The JAR ends up at:
bazel-bin/tensorflow/contrib/android/libandroid_tensorflow_inference_java.jar
4. Alternatively, the .so can be taken from the prebuilt files that TensorFlow publishes; see:
http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/libtensorflow_inference.so/
Setting up the Android side

1. Put the JAR under app -> libs and add the dependency in build.gradle: compile files('libs/libandroid_tensorflow_inference_java.jar');
2. Create a jniLibs folder under src -> main and put the .so library there (inside a subdirectory named after its ABI, e.g. jniLibs/armeabi-v7a);
3. Freeze the model trained on the PC into a .pb file:
from tensorflow.python.framework import graph_util

output_graph_def = graph_util.convert_variables_to_constants(
    sess, sess.graph_def, output_node_names=['output'])
with tf.gfile.FastGFile("path/to/xxx.pb", "wb") as f:
    f.write(output_graph_def.SerializeToString())
4. Put the .pb file under src -> main -> assets;
5. The code below walks through how the TensorFlow classifier is put together on the Android side:
public class TensorFlowAudioClassifier implements Classifier {
private static final String TAG = "TensorFlowAudioClassifier";
// Only return this many results with at least this confidence.
private static final int MAX_RESULTS = 3;
private static final float THRESHOLD = 0.0f;
// Config values.
// Name of the input node (without the trailing ':0'; just the name, e.g. 'input')
private String inputName;
// Name of the output node (same convention as the input node)
private String outputName;
// Size of the input matrix (the input is assumed square, so this is its side length)
private int inputSize;
// Pre-allocated buffers.
private Vector<String> labels = new Vector<String>();
private float[] floatValues;
private float[] outputs;
private String[] outputNames;
private TensorFlowInferenceInterface inferenceInterface;
// Private constructor; instances are created via the create() factory method.
private TensorFlowAudioClassifier() {
}
/**
* Initializes a native TensorFlow session for classification.
*
* @param assetManager The asset manager to be used to load assets.
* @param modelFilename The filepath of the model GraphDef protocol buffer.
* @param labelFilename The filepath of the label file for classes.
* @param inputSize The input size. A square input of inputSize x inputSize is assumed.
* @param inputName The label of the input node.
* @param outputName The label of the output node.
* @throws IOException
*/
public static Classifier create(
AssetManager assetManager,
String modelFilename,
String labelFilename,
int inputSize,
String inputName,
String outputName)
throws IOException {
TensorFlowAudioClassifier c = new TensorFlowAudioClassifier();
c.inputName = inputName;
c.outputName = outputName;
// Read the label names into memory.
// TODO(andrewharp): make this handlenon-assets.
// Load the label file; the labels are used later to build the result beans.
String actualFilename = labelFilename.split("file:///android_asset/")[1];
Log.i(TAG, "Reading labels from: " + actualFilename);
BufferedReader br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
String line;
while ((line = br.readLine()) != null){
c.labels.add(line);
}
br.close();
c.inferenceInterface = new TensorFlowInferenceInterface();
if (c.inferenceInterface.initializeTensorFlow(assetManager, modelFilename) != 0) {
throw new RuntimeException("TF initialization failed");
}
// The shape of the output is [N, NUM_CLASSES], where N is the batch size.
int numClasses =
(int) c.inferenceInterface.graph().operation(outputName).output(0).shape().size(1);
Log.i(TAG, "Read " + c.labels.size() + " labels, output layer size is " + numClasses);
// Ideally, inputSize could have been retrieved from the shape of the input operation. Alas,
// the placeholder node for input in the graphdef typically used does not specify a shape, so it
// must be passed in as a parameter.
c.inputSize = inputSize;
// Pre-allocate buffers.
c.outputNames = new String[]{outputName};
c.floatValues = new float[inputSize * inputSize * 1];
c.outputs = new float[numClasses];
return c;
}
@Override
// The recognition pipeline.
public List<Recognition> recognizeAudio(String fileName) {
// Log this method so that it can be analyzed with systrace.
Trace.beginSection("recognizeAudio");
Trace.beginSection("preprocessAudio");
// Preprocess the audio data to normalized float values
// based on the provided parameters.
// Build the input from the audio file. TensorFlow expects a flat 1-D float
// array here, so the 2-D feature matrix must be flattened.
double[][] data = RFFT.inputData(fileName);
for (int i = 0; i < data.length; ++i) {
for (int j = 0; j < data[0].length; ++j) {
floatValues[i * 40 + j] = (float) data[i][j];
}
}
Trace.endSection();
// Copy the input data into TensorFlow.
Trace.beginSection("fillNodeFloat");
// Feed the input into the InferenceInterface.
inferenceInterface.fillNodeFloat(
inputName, new int[]{1, 40, 40, 1}, floatValues);
Trace.endSection();
// Run the inference call.
Trace.beginSection("runInference");
// Run the model.
inferenceInterface.runInference(outputNames);
Trace.endSection();
// Copy the output Tensor back into the output array.
Trace.beginSection("readNodeFloat");
// Read the confidences from the output node into the outputs array.
inferenceInterface.readNodeFloat(outputName, outputs);
Trace.endSection();
// Find the best classifications.
// Use a PriorityQueue to pick the top-3 results. Recognition comes from the
// Classifier interface and is a simple bean class.
PriorityQueue<Recognition> pq =
new PriorityQueue<Recognition>(
3,
new Comparator<Recognition>() {
@Override
public int compare(Recognition lhs, Recognition rhs) {
// Intentionally reversed to put high confidence at the head of the queue.
return Float.compare(rhs.getConfidence(), lhs.getConfidence());
}
});
for (int i = 0; i < outputs.length; ++i) {
if (outputs[i] > THRESHOLD) {
pq.add(
// Build the bean: label id, label name, confidence.
new Recognition(
"" + i, labels.size() > i ? labels.get(i) : "unknown", outputs[i]));
}
}
final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
if (recognitionsSize == 0) {
return null;
}
for (int i = 0; i < recognitionsSize; ++i) {
recognitions.add(pq.poll());
}
Trace.endSection(); //"recognizeAudio"
return recognitions;
}
@Override
public void enableStatLogging(boolean debug) {
inferenceInterface.enableStatLogging(debug);
}
@Override
public String getStatString() {
return inferenceInterface.getStatString();
}
@Override
public void close() {
inferenceInterface.close();
}
}
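The ranking step in recognizeAudio above (discard scores at or below THRESHOLD, return at most MAX_RESULTS results ordered by confidence) is easy to get subtly wrong. Here is a minimal Python sketch of the same logic, where heapq.nlargest stands in for the Java PriorityQueue and the labels and scores are made-up values:

```python
import heapq

MAX_RESULTS = 3
THRESHOLD = 0.0

def top_results(outputs, labels):
    """Mirror of the ranking in recognizeAudio: keep scores above THRESHOLD
    and return at most MAX_RESULTS (label, confidence) pairs, best first."""
    candidates = [
        (labels[i] if i < len(labels) else "unknown", score)
        for i, score in enumerate(outputs)
        if score > THRESHOLD
    ]
    return heapq.nlargest(MAX_RESULTS, candidates, key=lambda lc: lc[1])

# Four class scores but only three known labels: the extra class maps to "unknown".
print(top_results([0.1, 0.7, 0.05, 0.15], ["dog", "door", "glass"]))
# → [('door', 0.7), ('unknown', 0.15), ('dog', 0.1)]
```

Note that a score exactly equal to THRESHOLD is dropped, matching the strict `>` comparison in the Java code.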
We have now built a working TensorFlowAudioClassifier; here is how to use it:
public static TensorFlowAudioClassifier classifier;
private static final String INPUT_NAME = "input";
private static final String OUTPUT_NAME = "output";
private static final String MODEL_FILE = "file:///android_asset/acoustic.pb";
private static final String LABEL_FILE =
"file:///android_asset/eventLabel.txt";
private static int INPUT_SIZE = 40;
try {
// Create the classifier.
classifier = (TensorFlowAudioClassifier) TensorFlowAudioClassifier.create(
getAssets(),
MODEL_FILE,
LABEL_FILE,
INPUT_SIZE,
INPUT_NAME,
OUTPUT_NAME
);
// Classify the audio file; note that recognizeAudio returns a list of results.
List<Recognition> results = classifier.recognizeAudio(fileName);
} catch (IOException e) {
e.printStackTrace();
}
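The INPUT_SIZE of 40 matches the [1, 40, 40, 1] shape passed to fillNodeFloat: the floatValues[i * 40 + j] loop is simply a row-major flatten of a 40x40 feature matrix. A minimal sketch of that layout in Python (numpy assumed; the feature values are made up):

```python
import numpy as np

INPUT_SIZE = 40  # side length of the square feature matrix

def flatten_features(data):
    """Row-major flatten of a 2-D feature matrix into the 1-D float array
    that fillNodeFloat expects: flat[i * INPUT_SIZE + j] == data[i][j]."""
    data = np.asarray(data, dtype=np.float32)
    assert data.shape == (INPUT_SIZE, INPUT_SIZE)
    return data.reshape(-1)  # length 40 * 40 = 1600

flat = flatten_features(np.arange(INPUT_SIZE * INPUT_SIZE).reshape(INPUT_SIZE, INPUT_SIZE))
print(flat[1 * INPUT_SIZE + 2])  # equals data[1][2], i.e. 42.0
```

Getting this index arithmetic wrong silently scrambles the input, so it is worth checking against a known matrix like the one above before debugging the model itself.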
Summary

We have now successfully deployed a TensorFlow deep model on Android. Note that the idea is never to train on the device, only to run inference there (training on Android is wishful thinking; training is hard enough on a PC, and deeper models even call for distributed training).
If this blog post helped you, please share it with others.