Tomcat部署图片分类模型系列(二)

前言

之前已经完成了java端环境的部署:Tomcat部署图片分类模型系列(一),接下来将完成模型的读取。

模型的读取以及排坑

  • 模型转化为pb格式

将模型部署在java端,需要先将模型转化为.pb格式。这里笔者使用的是keras训练的模型,下面是h5转pb的代码(以下是python代码,需要在python解释器下运行)

# -*- coding: utf-8 -*-

"""
将keras的.h5的模型文件,转换成TensorFlow的pb文件
"""
# ==========================================================

from keras.models import load_model
import tensorflow as tf
import os.path as osp
import os
from keras import backend
#from keras.models import Sequential

def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
    """Freeze a loaded Keras model and write it as a TensorFlow .pb file.

    Args:
        h5_model: The loaded Keras model object (for example the result of
            ``load_model``), not a file path; its ``outputs`` are read here.
        output_dir: Directory the .pb file is written to. It is created,
            including missing parent directories, when it does not exist.
        model_name: File name of the .pb file inside ``output_dir``.
        out_prefix: Prefix of the generated output node names. Output number
            ``i`` (1-based) is named ``out_prefix + str(i)``; adjust to match
            what the consumer of the .pb file fetches.
        log_tensorboard: When true, the written .pb file is additionally
            passed to ``import_pb_to_tensorboard.import_to_tensorboard`` with
            ``output_dir`` as the log directory.

    Returns:
        None. The .pb file (and optional log) is produced as a side effect.
    """
    # makedirs + exist_ok also handles nested paths and an already-existing dir.
    os.makedirs(output_dir, exist_ok=True)

    # Give every model output a predictable node name via an identity op.
    out_nodes = []
    for index, output_tensor in enumerate(h5_model.outputs, start=1):
        node_name = out_prefix + str(index)
        out_nodes.append(node_name)
        tf.identity(output_tensor, node_name)
    sess = backend.get_session()

    from tensorflow.python.framework import graph_util, graph_io
    # Convert variables to constants and write the binary .pb file.
    init_graph = sess.graph.as_graph_def()
    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
    graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)

    # Optionally emit a log for TensorBoard.
    if log_tensorboard:
        from tensorflow.python.tools import import_pb_to_tensorboard
        import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)


if __name__ == '__main__':
    # Location of the source .h5 model file.
    input_path = 'D:/CSP'  # folder that contains the .h5 file
    weight_file = 'xingren.h5'
    weight_file_path = os.path.join(input_path, weight_file)
    # Name the .pb after the .h5 file, swapping only the extension.
    output_graph_name = os.path.splitext(weight_file)[0] + '.pb'

    # Output directory for the .pb file: "trans_model" under the current working directory.
    output_dir = osp.join(os.getcwd(), "trans_model")

    # Load the Keras model and convert it.
    h5_model = load_model(weight_file_path)
    h5_to_pb(h5_model, output_dir=output_dir, model_name=output_graph_name)
    print('Finished')

  • java端读取pb文件
    读取文件的代码,笔者参考了LabelImage.java,代码基本相似。现在上代码。
package test;

import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;

import org.apache.commons.io.FileUtils;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Output;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Shape;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.TensorFlow;
import org.tensorflow.types.UInt8;



/**
 * Classifies a single JPEG image with a frozen TensorFlow graph (.pb file) and prints the
 * per-label scores together with the best-scoring label.
 *
 * <p>The model path, label file and image path are hard-coded in {@link #main(String[])}; the
 * input/output node names are hard-coded in {@link #executeInceptionGraph(byte[], Tensor)}.
 */
public class tensorflow {

  /**
   * Decodes JPEG bytes, resizes the image to 224x224 and normalizes it as
   * {@code (value - mean) / scale}, returning a float tensor with a leading batch dimension.
   *
   * @param imageBytes raw bytes of a JPEG file
   * @return the normalized image tensor; the caller is responsible for closing it
   */
  private static Tensor<Float> constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
    try (Graph g = new Graph()) {
      GraphBuilder b = new GraphBuilder(g);
      // Some constants specific to the pre-trained model at:
      // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
      //
      // - The model was trained with images scaled to 224x224 pixels.
      // - The colors, represented as R, G, B in 1-byte each were converted to
      //   float using (value - Mean)/Scale.
      final int H = 224;
      final int W = 224;
      // The next two values are tunable and influence the model output.
      final float mean = 100f;
      final float scale = 128f;
      // Since the graph is being constructed once per execution here, we can use a constant for the
      // input image. If the graph were to be re-used for multiple input images, a placeholder would
      // have been more appropriate.
      final Output<String> input = b.constant("input", imageBytes);
      final Output<Float> output =
          b.div(
              b.sub(
                  b.resizeBilinear(
                      b.expandDims(
                          b.cast(b.decodeJpeg(input, 3), Float.class),
                          b.constant("make_batch", 0)),
                      b.constant("size", new int[] {H, W})),
                  b.constant("mean", mean)),
              b.constant("scale", scale));
      try (Session s = new Session(g)) {
        return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class);
      }
    }
  }

  /**
   * Runs the serialized graph on one image and returns the scores for every label.
   *
   * <p>The image is fed to node {@code "input_1"} and the result is fetched from node
   * {@code "output_1"}; change both names to match the nodes of the model being loaded.
   *
   * @param graphDef serialized GraphDef (contents of the .pb file)
   * @param image normalized image tensor
   * @return the N scores of the single batch entry
   * @throws RuntimeException if the fetched tensor is not shaped [1 N]
   */
  private static float[] executeInceptionGraph(byte[] graphDef, Tensor<Float> image) {
    try (Graph g = new Graph()) {
      g.importGraphDef(graphDef);
      try (Session s = new Session(g);
          Tensor<Float> result =
              s.runner().feed("input_1", image).fetch("output_1").run().get(0).expect(Float.class)) {
        final long[] rshape = result.shape();
        if (result.numDimensions() != 2 || rshape[0] != 1) {
          throw new RuntimeException(
              String.format(
                  "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                  Arrays.toString(rshape)));
        }
        int nlabels = (int) rshape[1];
        return result.copyTo(new float[1][nlabels])[0];
      }
    }
  }

  /**
   * Returns the index of the largest value; the first one wins on ties, and 0 is returned for
   * an empty array.
   */
  private static int maxIndex(float[] probabilities) {
    int best = 0;
    for (int i = 1; i < probabilities.length; ++i) {
      if (probabilities[i] > probabilities[best]) {
        best = i;
      }
    }
    return best;
  }

  /** Reads the whole file as bytes; on I/O failure prints the error and exits with status 1. */
  private static byte[] readAllBytesOrExit(Path path) {
    try {
      return Files.readAllBytes(path);
    } catch (IOException e) {
      System.err.println("Failed to read [" + path + "]: " + e.getMessage());
      System.exit(1);
    }
    return null; // unreachable at runtime; required by the compiler
  }

  /**
   * Reads the whole file as UTF-8 lines; on I/O failure prints the error and exits with status 1
   * (a failure must not report the success status 0).
   */
  private static List<String> readAllLinesOrExit(Path path) {
    try {
      return Files.readAllLines(path, Charset.forName("UTF-8"));
    } catch (IOException e) {
      System.err.println("Failed to read [" + path + "]: " + e.getMessage());
      System.exit(1);
    }
    return null; // unreachable at runtime; required by the compiler
  }

  /**
   * Loads the model, labels and image from the hard-coded paths, prints every label score on one
   * line, then prints the best-scoring label.
   */
  public static void main(String[] args) {
    String modelDir = "E:\\Desktop\\"; // directory that contains the model
    String imageFile = "E:\\Desktop\\dog.jpg"; // path of the image to classify
    byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "class42.pb"));
    // Labels are read from a text file, one per line; they could also be supplied by hand.
    List<String> labels = readAllLinesOrExit(Paths.get(modelDir, "class42.txt"));
    // Raw bytes of the image file.
    byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));
    // Convert the image into a tensor of the size the model expects.
    try (Tensor<Float> image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
      float[] labelProbabilities = executeInceptionGraph(graphDef, image);
      int bestLabelIdx = maxIndex(labelProbabilities);
      // Iterate over the actual output length instead of assuming exactly 42 classes.
      for (float probability : labelProbabilities) {
        System.out.print(probability + " ");
      }
      System.out.println();
      System.out.println(
          String.format(
              "该图片可能是: %s (%.2f%% likely)",
              labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
    }
  }

  // In the fullness of time, equivalents of the methods of this class should be auto-generated from
  // the OpDefs linked into libtensorflow_jni.so. That would match what is done in other languages
  // like Python, C++ and Go.
  /** Thin helper that adds individual operations to a {@link Graph} by op type name. */
  static class GraphBuilder {
    GraphBuilder(Graph g) {
      this.g = g;
    }

    Output<Float> div(Output<Float> x, Output<Float> y) {
      return binaryOp("Div", x, y);
    }

    <T> Output<T> sub(Output<T> x, Output<T> y) {
      return binaryOp("Sub", x, y);
    }

    <T> Output<Float> resizeBilinear(Output<T> images, Output<Integer> size) {
      return binaryOp3("ResizeBilinear", images, size);
    }

    <T> Output<T> expandDims(Output<T> input, Output<Integer> dim) {
      return binaryOp3("ExpandDims", input, dim);
    }

    <T, U> Output<U> cast(Output<T> value, Class<U> type) {
      DataType dtype = DataType.fromClass(type);
      return g.opBuilder("Cast", "Cast")
          .addInput(value)
          .setAttr("DstT", dtype)
          .build()
          .<U>output(0);
    }

    Output<UInt8> decodeJpeg(Output<String> contents, long channels) {
      return g.opBuilder("DecodeJpeg", "DecodeJpeg")
          .addInput(contents)
          .setAttr("channels", channels)
          .build()
          .<UInt8>output(0);
    }

    /** Adds a "Const" op called {@code name} holding {@code value}; the temporary tensor is closed. */
    <T> Output<T> constant(String name, Object value, Class<T> type) {
      try (Tensor<T> t = Tensor.<T>create(value, type)) {
        return g.opBuilder("Const", name)
            .setAttr("dtype", DataType.fromClass(type))
            .setAttr("value", t)
            .build()
            .<T>output(0);
      }
    }

    Output<String> constant(String name, byte[] value) {
      return this.constant(name, value, String.class);
    }

    Output<Integer> constant(String name, int value) {
      return this.constant(name, value, Integer.class);
    }

    Output<Integer> constant(String name, int[] value) {
      return this.constant(name, value, Integer.class);
    }

    Output<Float> constant(String name, float value) {
      return this.constant(name, value, Float.class);
    }

    /** Two-input op whose output type equals its input type; the op name equals the op type. */
    private <T> Output<T> binaryOp(String type, Output<T> in1, Output<T> in2) {
      return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
    }

    /** Two-input op whose inputs and output may all have different element types. */
    private <T, U, V> Output<T> binaryOp3(String type, Output<U> in1, Output<V> in2) {
      return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
    }

    private Graph g;
  }
}

模型部署排坑

  1. Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced
    运行时程序会抛出上述异常,问题出在pb文件上,具体的解决方法可以参考:Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced

  2. Tensor<Float> result =
    (Tensor<Float>) s.runner().feed("input_1", image).fetch("output_1").run().get(0).expect(Float.class))

    这句代码要根据自己的模型手动修改:其中.feed()中的"input_1"是笔者模型的输入节点的名称,.fetch("output_1")中的"output_1"是输出节点的名称,各位要根据自己模型的节点名称填写。下面笔者提供查找pb文件中节点名称的代码(此为python代码)

import tensorflow as tf

# Read the serialized model; the first argument of GFile is the path of the .pb file.
with tf.gfile.GFile('E:\\Desktop\\class42.pb', "rb") as pb_file:
    graph_def = tf.GraphDef()
    serialized_model = pb_file.read()
    graph_def.ParseFromString(serialized_model)  # fills graph_def with the graph and its data

# Graph must be instantiated with parentheses, otherwise a TypeError is raised.
with tf.Graph().as_default() as graph:
    # Import into this new graph; name="" replaces the default "import" name.
    tf.import_graph_def(graph_def, name="")
    # Print the name and output values of every node in the graph.
    for node in graph.get_operations():
        print("{}     {}".format(node.name, node.values()))

到此我遇到的坑就基本结束了,希望大家多多指正。
接下来就是使用Servlet接受图片并预测了。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值