Train TF models in Python and Invoke models in Java

||||Plan A|||| {train and save in Python, load and predict in Java}
#Train in Python
import tensorflow as tf  # TF 1.x API; on TF 2.x use: import tensorflow.compat.v1 as tf; tf.disable_v2_behavior()
# Approach taken from the Stack Overflow Documentation example:
# https://stackoverflow.com/documentation/tensorflow/10718/save-tensorflow-model-in-python-and-load-with-java#t=201709030336395954421
tf.reset_default_graph()

# DO MODEL STUFF
# Pretrained weighting of 2.0
W = tf.get_variable('w', initializer=tf.constant(2.0), dtype=tf.float32)
# Model input x
x = tf.placeholder(tf.float32, name='x')
# Model output y = W*x
y = tf.multiply(W, x, name='y')

# DO SESSION STUFF
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# SAVE THE MODEL
builder = tf.saved_model.builder.SavedModelBuilder("/tmp/model")
builder.add_meta_graph_and_variables(
  sess,
  [tf.saved_model.tag_constants.SERVING]
)
builder.save()
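SavedModelBuilder refuses to write into an export directory that already exists (later 1.x releases accept an empty one). It raises an AssertionError instead, so running the script above a second time fails at /tmp/model. Below is a minimal standard-library sketch that clears the old export first. The helper name prepare_export_dir is hypothetical and is not a TensorFlow API.

```python
import os
import shutil


def prepare_export_dir(export_dir: str) -> str:
    """Remove a previous export so SavedModelBuilder can create export_dir afresh.

    Hypothetical helper, not part of TensorFlow. It deletes export_dir
    recursively, so only point it at a directory you are happy to lose.
    """
    if os.path.isdir(export_dir):
        shutil.rmtree(export_dir)  # deletes the old saved_model.pb and variables/
    elif os.path.exists(export_dir):
        os.remove(export_dir)  # a stray file with the same name also blocks the export
    return export_dir
```

Usage: builder = tf.saved_model.builder.SavedModelBuilder(prepare_export_dir("/tmp/model")). The helper returns the path unchanged, so it can wrap the argument in place.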
//Invoke in Java
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;

import java.io.IOException;
import java.nio.FloatBuffer;

/**
 * Created by apollo on 17-9-3.
 * https://stackoverflow.com/documentation/tensorflow/10718/save-tensorflow-model-in-python-and-load-with-java#t=201709030336395954421
 */
public class LoadModel {

    public static void main(String[] args) throws IOException {
        // good idea to print the version number, 1.2.0 as of this writing
        System.out.println(TensorFlow.version());
        final int NUM_PREDICTIONS = 1;

        /* load the model Bundle */
        SavedModelBundle b = SavedModelBundle.load("/tmp/model", "serve");

        // create the session from the Bundle
        Session sess = b.session();
        // create an input Tensor, value = 2.0f
        Tensor x = Tensor.create(
                new long[]{NUM_PREDICTIONS},
                FloatBuffer.wrap(new float[]{2.0f})
        );

        // run the model and get the result, 4.0f.
        float[] y = sess.runner()
                .feed("x", x)
                .fetch("y")
                .run()
                .get(0)
                .copyTo(new float[NUM_PREDICTIONS]);

        // print out the result.
        System.out.println(y[0]);
    }
}
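When SavedModelBundle.load fails on the Java side, the export directory is the first thing to check. As far as I know, the loader expects these files:

- saved_model.pb (or saved_model.pbtxt) at the top level.
- The variable values under variables/, as variables.index plus one or more variables.data-* shards.

Below is a standard-library sketch of that check, run from Python. The name saved_model_problems is hypothetical. It checks only the file layout: it does not parse the protobuf or confirm that the "serve" tag is present.

```python
import os
from typing import List


def saved_model_problems(export_dir: str) -> List[str]:
    """List layout problems that would stop a SavedModel from loading (empty list = looks fine)."""
    problems = []
    graph_files = ("saved_model.pb", "saved_model.pbtxt")
    if not any(os.path.isfile(os.path.join(export_dir, name)) for name in graph_files):
        problems.append("missing saved_model.pb / saved_model.pbtxt")
    variables_dir = os.path.join(export_dir, "variables")
    if not os.path.isdir(variables_dir):
        problems.append("missing variables/ directory")
        return problems
    if not os.path.isfile(os.path.join(variables_dir, "variables.index")):
        problems.append("missing variables/variables.index")
    # Weights are stored in shards named variables.data-NNNNN-of-NNNNN.
    if not any(name.startswith("variables.data-") for name in os.listdir(variables_dir)):
        problems.append("missing variables/variables.data-* shard(s)")
    return problems
```

After the Python script in Plan A has run, saved_model_problems("/tmp/model") should return an empty list. A graph with no variables at all may legitimately leave variables/ empty, so treat the last two messages as hints rather than hard errors.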

==============================================

||||Plan B|||| {only in Python, only import}
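On the Python side, TensorFlow recommends saving a model to disk with a tf.train.Saver object. Saver writes a .meta file that holds the graph definition, plus .index and .data files that hold the variable values (the weights). In Python, I read the model back from disk with:
new_saver = tf.train.import_meta_graph(var_filename)
new_saver.restore(sess, model_filename)
Here var_filename is the .meta file and model_filename is the checkpoint prefix, i.e. the common path without any extension.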
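Saver files share a common prefix. With the default V2 format, saving to model-100 produces three files:

- model-100.meta, which holds the graph.
- model-100.index, which holds the weights index.
- model-100.data-00000-of-00001, which holds the weight values.

import_meta_graph wants the .meta file, while restore wants the bare prefix, and the two are easy to mix up. Below is a small standard-library sketch that derives both from any of these paths. split_checkpoint_path is a hypothetical helper, not a TensorFlow function.

```python
import re
from typing import Tuple

# Extensions a V2 checkpoint adds to its prefix: .meta, .index, .data-NNNNN-of-NNNNN
_CKPT_SUFFIX = re.compile(r"\.(meta|index|data-\d{5}-of-\d{5})$")


def split_checkpoint_path(path: str) -> Tuple[str, str]:
    """Return (meta_graph_file, checkpoint_prefix) for a Saver checkpoint path."""
    prefix = _CKPT_SUFFIX.sub("", path)  # strip one known extension, if present
    return prefix + ".meta", prefix
```

With it, the two lines above become meta_file, prefix = split_checkpoint_path(model_filename), followed by tf.train.import_meta_graph(meta_file) and new_saver.restore(sess, prefix).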

||||Plan C|||| {only in Python, only save}
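tf.train.write_graph(sess.graph_def, "./data", "aaa.pb", as_text=False)
as_text defaults to True, which writes a human-readable text protobuf even when the file name ends in .pb. Passing as_text=False writes the binary format.
Note that aaa.pb holds only the graph definition (GraphDef), not the trained variable values. That makes it no more self-contained than Plan A's saved_model.pb, which keeps its weights in the variables/ subdirectory. To get a single .pb with the weights embedded, freeze the graph first, e.g. with tf.graph_util.convert_variables_to_constants.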

||||Plan D|||| {only in Python, only save works; import and predict have errors}
# https://github.com/jiegzhan/multi-class-text-classification-cnn-rnn
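# Save side (training). tf.all_variables() is deprecated; tf.global_variables() replaces it.
saver = tf.train.Saver(tf.global_variables())
# Writes checkpoint_prefix-<current_step>.meta/.index/.data-* and returns that prefix.
path = saver.save(sess, checkpoint_prefix, global_step=current_step)

Import and predict side, where the errors occur:
# checkpoint_file[:-5] drops the last five characters; the result plus ".meta" must name an existing file.
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file[:-5]))
saver.restore(sess, checkpoint_file)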
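saver.save(..., global_step=current_step) appends -<step> to checkpoint_prefix, so the prediction script has to know the step to rebuild the file names. Hard-coding a path and slicing it, as with checkpoint_file[:-5], is fragile.

TensorFlow's tf.train.latest_checkpoint(checkpoint_dir) avoids this. It reads the small "checkpoint" state file that Saver keeps next to the checkpoint files and returns the newest prefix. The standard-library sketch below mimics that lookup to show what it does. It is a simplified stand-in with a hypothetical name, latest_checkpoint_prefix: it ignores text-proto escaping and does not check that the files exist. In real TF code, prefer tf.train.latest_checkpoint.

```python
import os
import re
from typing import Optional

# Matches the line  model_checkpoint_path: "..."  in the checkpoint state file.
_LATEST = re.compile(r'^\s*model_checkpoint_path:\s*"(.*)"\s*$')


def latest_checkpoint_prefix(checkpoint_dir: str) -> Optional[str]:
    """Simplified stand-in for tf.train.latest_checkpoint (hypothetical helper)."""
    state_file = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.isfile(state_file):
        return None
    with open(state_file, encoding="utf-8") as f:
        for line in f:
            match = _LATEST.match(line)
            if match:
                prefix = match.group(1)
                # Saver may record the prefix relative to the checkpoint directory.
                if not os.path.isabs(prefix):
                    prefix = os.path.join(checkpoint_dir, prefix)
                return prefix
    return None
```

On the predict side, prefix = latest_checkpoint_prefix(checkpoint_dir) replaces the slicing. Then load with tf.train.import_meta_graph(prefix + ".meta") and saver.restore(sess, prefix).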

||||Plan E||||
https://stackoverflow.com/questions/43598953/loading-sklearn-model-in-java-model-created-with-dnnclassifier-in-python

classifier = learn.DNNClassifier(hidden_units=[10, 20, 5], n_classes=5, feature_columns=feature_columns)
A model_saved.pbtxt file is created.

SavedModelBundle bundle = SavedModelBundle.load("/java/…/ModelSave", "serve");
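SavedModelBundle.load needs the directory that directly contains the saved_model.pb/.pbtxt file and the variables/ folder. The tag ("serve") must also match one used at export time.

Suppose the classifier was exported with the Estimator's export_savedmodel. That is an assumption, since the post does not show the export call. In that case TensorFlow writes each export into a new timestamp-named subdirectory under the base directory you pass, so the Java path must include that subdirectory. Below is a standard-library sketch that picks the newest one. latest_export_dir is a hypothetical helper.

```python
import os
from typing import Optional


def latest_export_dir(export_dir_base: str) -> Optional[str]:
    """Return the newest timestamp-named export under export_dir_base, or None."""
    timestamps = [
        name
        for name in os.listdir(export_dir_base)
        # keep only all-digit directory names; skips temp dirs and stray files
        if name.isdigit() and os.path.isdir(os.path.join(export_dir_base, name))
    ]
    if not timestamps:
        return None
    # Compare numerically; export directories are named by integer seconds since the epoch.
    return os.path.join(export_dir_base, max(timestamps, key=int))
```

The path this returns is the one to pass as the first argument of SavedModelBundle.load on the Java side.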

Reference websites:

http://blog.csdn.net/michael_yt/article/details/74737489

http://blog.csdn.net/lujiandong1/article/details/53385092

https://blog.metaflow.fr/tensorflow-saving-restoring-and-mixing-multiple-models-c4c94d5d7125

http://www.cnblogs.com/nowornever-L/p/6991295.html
