Tensorflow学习笔记(四)模型的保存与加载(三)

Tensorflow学习笔记(四)模型的保存(三)

单个.pb模型的保存与加载以及安卓端的调用

声明: 参考链接这里
之前两种的保存方法保存的模型文件他的模型框架图和权重都是分开的,有时候我们希望他们能够合并在一起方便在其他地方调用比如安卓端。

保存

tf.GraphDef()

GraphDef()中没有包含网络中的Variable值,但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。

graph_util.convert_variables_to_constants可以把整个session当作常量都保存下来,通过output_node_names参数来指定输出。
如下:

output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def,output_node_names=['output'])

到这里有两种方法把output_graph_def写入.pb文件中两者大同小异我还是都贴出来了,如:

# 第一种
MODEL_SAVE_PATH = "./models/" # 保存模型的路径
tf.train.write_graph(output_graph_def, MODEL_SAVE_PATH, 'my_model.pb', as_text=False)

第二个参数是设定保存的路径,第三个参数是模型文件的名字,第四个参数如果为True就将图形作为ASCII原型写入,这里我们填入False

# 第二种
MODEL_SAVE_PATH = "./models/" # 保存模型的路径
#’wb’中w代表写文件,b代表将数据以二进制方式写入文件。
with tf.gfile.FastGFile(os.path.join(MODEL_SAVE_PATH, 'my_model.pb'), mode='wb') as f:
	f.write(output_graph_def.SerializeToString())

tf.gfile.FastGFile(os.path.join(MODEL_SAVE_PATH, 'my_model.pb'), mode='wb')指定保存文件的路径以及读写方式

f.write(output_graph_def.SerializeToString())将固化的模型写入到文件

完整样例:

import tensorflow as tf # 以下所有代码默认导入
import os
from tensorflow.python.framework import graph_util
# # 保存模型路径
MODEL_SAVE_PATH = "./models/" # 保存模型的路径
MODEL_NAME = "my_model" # 模型命名
# 创建一个变量
one = tf.Variable(2.0)
# 创建一个占位符,在 Tensorflow 中需要定义 placeholder 的 type ,一般为 float32 形式
num = tf.placeholder(tf.float32,name='input')
# 创建一个加法步骤,注意这里并没有直接计算
sum = tf.add(num,one,name='output')
# 初始化变量,如果定义Variable就必须初始化
init = tf.global_variables_initializer()
# 创建会话sess
with tf.Session() as sess:
    sess.run(init)
    output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output'])
    # 第一种
    tf.train.write_graph(output_graph_def, MODEL_SAVE_PATH, 'my_model.pb', as_text=False)
    # 第二种
    # with tf.gfile.FastGFile(os.path.join(MODEL_SAVE_PATH, 'my_model.pb'), mode='wb') as f:
    #     f.write(output_graph_def.SerializeToString())

运行结果如下
在这里插入图片描述

加载

以上两种方法调用加载的方法是通用的。

直接上例子:

import tensorflow as tf # 以下所有代码默认导入
# ###模型调用###
with tf.Session() as sess:
    with open('./models/my_model.pb','rb')as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        output = tf.import_graph_def(graph_def,input_map={'input:0':3.}, return_elements=['output:0'])
        print(sess.run(output))

结果:

[5.0]

另外需要说明的一点是,在利用tf.train.write_graph写网络架构的时候,如果令as_text=True了,则在导入网络的时候,需要做一点小修改。

import tensorflow as tf # 以下所有代码默认导入
from google.protobuf import text_format
# ###模型调用###
with tf.Session() as sess:
    with open('./models/my_model.pb','rb')as f:
        graph_def = tf.GraphDef()
        # graph_def.ParseFromString(f.read())
        text_format.Merge(f.read(), graph_def)
        output = tf.import_graph_def(graph_def,input_map={'input:0':3.}, return_elements=['output:0'])
        print(sess.run(output))

安卓端调用

新建一个安卓项目
在这里插入图片描述
改好自己的名字然后一路Next,最后Finish。

然后在build.gradle(app)中的dependencies {}添加

implementation 'org.tensorflow:tensorflow-android:1.11.0'

像这样

在这里插入图片描述
然后把PC端生成的模型放入assets目录!如果没有就新建一个,如

在这里插入图片描述
点击Finish,就生成了assets目录,然后把my_model.pb文件拖入

在这里插入图片描述

然后创建一个新的TF类,如下


import android.content.res.AssetManager;
import android.util.Log;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

public class TF {
	//模型存放路径
    private static final String MODEL_FILE = "file:///android_asset/my_model.pb"; 
    //数据的维度
    private static final int HEIGHT = 1;
    private static final int WIDTH = 1;
    //模型中输出变量的名称
    private static final String inputName = "input";
    //用于存储的模型输入数据
    private float[]  inputs = new float[1];
    //模型中输出变量的名称
    private static final String outputName = "output";
    //用于存储模型的输出数据
    private float[]  outputs = new float[1];

    static {
        //加载库文件
        System.loadLibrary("tensorflow_inference");
    }

    TensorFlowInferenceInterface inferenceInterface;

    TF(AssetManager assetManager) {
        //接口定义
        Log.e("模型","TensoFlow模型文件加载");
        inferenceInterface = new TensorFlowInferenceInterface(assetManager,MODEL_FILE);
        Log.e("模型","TensoFlow模型文件加载成功");
    }
    public String Use_model(float num) {
        //为输入数据赋值
        inputs[0] = 3;

        //将数据feed给tensorflow
        inferenceInterface.feed(inputName, inputs, HEIGHT,WIDTH);

        //运行
        String[] outputNames = new String[] {outputName};
        inferenceInterface.run(outputNames);

        //将输出存放到outputs中
        inferenceInterface.fetch(outputName, outputs);

        String str = outputs[0]+"";

        return str;
    }
}

activity_main.xml中给TextView加上IDandroid:id="@+id/text"在这里插入图片描述

主函数中调用,Use_model()方法传入3.0,返回值让TextView显示

public class MainActivity extends AppCompatActivity {


    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        TextView tx = (TextView)findViewById(R.id.text);
        TF tf = new TF(getAssets());
        String num = tf.Use_model(3.0);
        tx.setText(num);
    }
}

最后下载到手机上

在这里插入图片描述

希望这篇文章对您有帮助,感谢阅读!

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值