Tensorflow学习笔记(四)模型的保存(三)
单个.pb模型的保存与加载以及安卓端的调用
声明: 参考链接这里
之前两种的保存方法保存的模型文件他的模型框架图和权重都是分开的,有时候我们希望他们能够合并在一起方便在其他地方调用比如安卓端。
保存
tf.GraphDef()
GraphDef()
中没有包含网络中的Variable值,但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。
graph_util.convert_variables_to_constants
可以把整个session当作常量都保存下来,通过output_node_names
参数来指定输出。
如下:
output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def,output_node_names=['output'])
到这里有两种方法把output_graph_def
写入.pb
文件中两者大同小异我还是都贴出来了,如:
# 第一种
MODEL_SAVE_PATH = "./models/" # 保存模型的路径
tf.train.write_graph(output_graph_def, MODEL_SAVE_PATH, 'my_model.pb', as_text=False)
第二个参数是设定保存的路径,第三个参数是模型文件的名字,第四个参数如果为True
就将图形作为ASCII原型写入,这里我们填入False
。
# 第二种
MODEL_SAVE_PATH = "./models/" # 保存模型的路径
#’wb’中w代表写文件,b代表将数据以二进制方式写入文件。
with tf.gfile.FastGFile(os.path.join(MODEL_SAVE_PATH, 'my_model.pb'), mode='wb') as f:
f.write(output_graph_def.SerializeToString())
tf.gfile.FastGFile(os.path.join(MODEL_SAVE_PATH, 'my_model.pb'), mode='wb')
指定保存文件的路径以及读写方式
f.write(output_graph_def.SerializeToString())
将固化的模型写入到文件
完整样例:
import tensorflow as tf # 以下所有代码默认导入
import os
from tensorflow.python.framework import graph_util
# # 保存模型路径
MODEL_SAVE_PATH = "./models/" # 保存模型的路径
MODEL_NAME = "my_model" # 模型命名
# 创建一个变量
one = tf.Variable(2.0)
# 创建一个占位符,在 Tensorflow 中需要定义 placeholder 的 type ,一般为 float32 形式
num = tf.placeholder(tf.float32,name='input')
# 创建一个加法步骤,注意这里并没有直接计算
sum = tf.add(num,one,name='output')
# 初始化变量,如果定义Variable就必须初始化
init = tf.global_variables_initializer()
# 创建会话sess
with tf.Session() as sess:
sess.run(init)
output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output'])
# 第一种
tf.train.write_graph(output_graph_def, MODEL_SAVE_PATH, 'my_model.pb', as_text=False)
# 第二种
# with tf.gfile.FastGFile(os.path.join(MODEL_SAVE_PATH, 'my_model.pb'), mode='wb') as f:
# f.write(output_graph_def.SerializeToString())
运行结果如下
加载
以上两种方法调用加载的方法是通用的。
直接上例子:
import tensorflow as tf # 以下所有代码默认导入
# ###模型调用###
with tf.Session() as sess:
with open('./models/my_model.pb','rb')as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
output = tf.import_graph_def(graph_def,input_map={'input:0':3.}, return_elements=['output:0'])
print(sess.run(output))
结果:
[5.0]
另外需要说明的一点是,在利用tf.train.write_graph
写网络架构的时候,如果令as_text=True
了,则在导入网络的时候,需要做一点小修改。
import tensorflow as tf # 以下所有代码默认导入
from google.protobuf import text_format
# ###模型调用###
with tf.Session() as sess:
with open('./models/my_model.pb','rb')as f:
graph_def = tf.GraphDef()
# graph_def.ParseFromString(f.read())
text_format.Merge(f.read(), graph_def)
output = tf.import_graph_def(graph_def,input_map={'input:0':3.}, return_elements=['output:0'])
print(sess.run(output))
安卓端调用
新建一个安卓项目
改好自己的名字然后一路Next,最后Finish。
然后在build.gradle(app)中的dependencies {}
添加
implementation 'org.tensorflow:tensorflow-android:1.11.0'
像这样
然后把PC端生成的模型放入assets目录!如果没有就新建一个,如
点击Finish,就生成了assets目录,然后把my_model.pb
文件拖入
然后创建一个新的TF类,如下
import android.content.res.AssetManager;
import android.util.Log;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
public class TF {
//模型存放路径
private static final String MODEL_FILE = "file:///android_asset/my_model.pb";
//数据的维度
private static final int HEIGHT = 1;
private static final int WIDTH = 1;
//模型中输出变量的名称
private static final String inputName = "input";
//用于存储的模型输入数据
private float[] inputs = new float[1];
//模型中输出变量的名称
private static final String outputName = "output";
//用于存储模型的输出数据
private float[] outputs = new float[1];
static {
//加载库文件
System.loadLibrary("tensorflow_inference");
}
TensorFlowInferenceInterface inferenceInterface;
TF(AssetManager assetManager) {
//接口定义
Log.e("模型","TensoFlow模型文件加载");
inferenceInterface = new TensorFlowInferenceInterface(assetManager,MODEL_FILE);
Log.e("模型","TensoFlow模型文件加载成功");
}
public String Use_model(float num) {
//为输入数据赋值
inputs[0] = 3;
//将数据feed给tensorflow
inferenceInterface.feed(inputName, inputs, HEIGHT,WIDTH);
//运行
String[] outputNames = new String[] {outputName};
inferenceInterface.run(outputNames);
//将输出存放到outputs中
inferenceInterface.fetch(outputName, outputs);
String str = outputs[0]+"";
return str;
}
}
activity_main.xml
中给TextView加上IDandroid:id="@+id/text"
主函数中调用,Use_model()
方法传入3.0,返回值让TextView显示
public class MainActivity extends AppCompatActivity {
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
TextView tx = (TextView)findViewById(R.id.text);
TF tf = new TF(getAssets());
String num = tf.Use_model(3.0);
tx.setText(num);
}
}
最后下载到手机上
希望这篇文章对您有帮助,感谢阅读!