Java Web 端调用 TensorFlow 模型

公司想做一个 Web 端的图像识别功能,可是网上的例子很少,不过在官网可以找到 Java 示例代码。

不过直接拿来用可能会出问题:我们 Web 端用的是 TensorFlow 1.3.0,而官网的示例已经更新到 1.4.0,代码又不一样了……(为什么要说“又”呢)

可以通过下载老版本的源码找到例子代码

还是贴一下1.3.0代码吧

 

/**
 * Command-line sample that labels a single JPEG image with a frozen TensorFlow graph,
 * written against the TensorFlow 1.3.0 Java API.
 *
 * <p>The program reads {@code tensorflow_inception_graph.pb} and
 * {@code imagenet_comp_graph_label_strings.txt} from the model directory, normalizes the
 * image with a small helper graph, runs the model and prints the most likely label.
 */
public class LabelImage {

    /** Model directory used when no command-line argument is supplied. */
    private static final String DEFAULT_MODEL_DIR = "C:\\sts";

    /** Image file used when no second command-line argument is supplied. */
    private static final String DEFAULT_IMAGE_FILE = "C:\\sts\\timg.jpg";

    /**
     * Entry point.
     *
     * @param args optional: {@code args[0]} is the model directory and {@code args[1]} the
     *     JPEG file to label; missing arguments fall back to the built-in defaults
     */
    public static void main(String[] args) {
        String modelDir = (args != null && args.length > 0) ? args[0] : DEFAULT_MODEL_DIR;
        String imageFile = (args != null && args.length > 1) ? args[1] : DEFAULT_IMAGE_FILE;

        byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "tensorflow_inception_graph.pb"));

        List<String> labels =
                readAllLinesOrExit(Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt"));
        byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));

        // The normalized image tensor holds native memory, so it is closed as soon as we are done.
        try (Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
            float[] labelProbabilities = executeInceptionGraph(graphDef, image);
            int bestLabelIdx = maxIndex(labelProbabilities);
            System.out.println(
                    String.format(
                            "BEST MATCH: %s (%.2f%% likely)",
                            labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
        }
    }

    /**
     * Builds and runs a throw-away graph that decodes the JPEG bytes, casts them to float,
     * adds a batch dimension, resizes to 224x224 and computes {@code (pixel - mean) / scale}.
     *
     * @param imageBytes raw contents of a JPEG file
     * @return the normalized image tensor; the caller is responsible for closing it
     */
    private static Tensor constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
        try (Graph g = new Graph()) {
            GraphBuilder b = new GraphBuilder(g);

            // Preprocessing constants used by this sample.
            final int H = 224;
            final int W = 224;
            final float mean = 117f;
            final float scale = 1f;

            final Output input = b.constant("input", imageBytes);
            final Output output =
                    b.div(
                            b.sub(
                                    b.resizeBilinear(
                                            b.expandDims(
                                                    b.cast(b.decodeJpeg(input, 3), DataType.FLOAT),
                                                    b.constant("make_batch", 0)),
                                            b.constant("size", new int[] {H, W})),
                                    b.constant("mean", mean)),
                            b.constant("scale", scale));
            try (Session s = new Session(g)) {
                return s.runner().fetch(output.op().name()).run().get(0);
            }
        }
    }

    /**
     * Imports the serialized graph, feeds the image to the node named {@code input} and
     * fetches the node named {@code output}.
     *
     * @param graphDef serialized graph definition
     * @param image normalized image tensor
     * @return the values of the first (and only) row of the output tensor
     * @throws RuntimeException if the output is not a rank-2 tensor with a first dimension of 1
     */
    private static float[] executeInceptionGraph(byte[] graphDef, Tensor image) {
        try (Graph g = new Graph()) {
            g.importGraphDef(graphDef);
            try (Session s = new Session(g);
                 Tensor result = s.runner().feed("input", image).fetch("output").run().get(0)) {
                final long[] rshape = result.shape();
                if (result.numDimensions() != 2 || rshape[0] != 1) {
                    throw new RuntimeException(
                            String.format(
                                    "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                                    Arrays.toString(rshape)));
                }
                int nlabels = (int) rshape[1];
                return result.copyTo(new float[1][nlabels])[0];
            }
        }
    }

    /**
     * Returns the index of the largest element; the first one wins on ties.
     *
     * @param probabilities values to scan
     * @return index of the maximum, or 0 when the array is empty
     */
    private static int maxIndex(float[] probabilities) {
        int best = 0;
        for (int i = 1; i < probabilities.length; ++i) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return best;
    }

    /**
     * Reads a whole file, or prints an error and terminates the JVM with status 1.
     *
     * @param path file to read
     * @return the file contents
     */
    private static byte[] readAllBytesOrExit(Path path) {
        try {
            return Files.readAllBytes(path);
        } catch (IOException e) {
            System.err.println("Failed to read [" + path + "]: " + e.getMessage());
            System.exit(1);
        }
        return null; // unreachable: System.exit does not return
    }

    /**
     * Reads all lines of a UTF-8 text file, or prints an error and terminates the JVM with
     * status 1.
     *
     * @param path file to read
     * @return the lines of the file
     */
    private static List<String> readAllLinesOrExit(Path path) {
        try {
            return Files.readAllLines(path, Charset.forName("UTF-8"));
        } catch (IOException e) {
            System.err.println("Failed to read [" + path + "]: " + e.getMessage());
            // A failed read is an error, so report a non-zero status like readAllBytesOrExit does.
            System.exit(1);
        }
        return null; // unreachable: System.exit does not return
    }

    /** Small helper that adds the operations used for image normalization to a graph. */
    static class GraphBuilder {
        private final Graph g;

        GraphBuilder(Graph g) {
            this.g = g;
        }

        Output div(Output x, Output y) {
            return binaryOp("Div", x, y);
        }

        Output sub(Output x, Output y) {
            return binaryOp("Sub", x, y);
        }

        Output resizeBilinear(Output images, Output size) {
            return binaryOp("ResizeBilinear", images, size);
        }

        Output expandDims(Output input, Output dim) {
            return binaryOp("ExpandDims", input, dim);
        }

        Output cast(Output value, DataType dtype) {
            return g.opBuilder("Cast", "Cast").addInput(value).setAttr("DstT", dtype).build().output(0);
        }

        Output decodeJpeg(Output contents, long channels) {
            return g.opBuilder("DecodeJpeg", "DecodeJpeg")
                    .addInput(contents)
                    .setAttr("channels", channels)
                    .build()
                    .output(0);
        }

        /** Adds a {@code Const} node named {@code name} holding {@code value}. */
        Output constant(String name, Object value) {
            // The temporary tensor is only needed while the node is being built.
            try (Tensor t = Tensor.create(value)) {
                return g.opBuilder("Const", name)
                        .setAttr("dtype", t.dataType())
                        .setAttr("value", t)
                        .build()
                        .output(0);
            }
        }

        /** Adds a two-input operation whose node name equals its type. */
        private Output binaryOp(String type, Output in1, Output in2) {
            return g.opBuilder(type, type).addInput(in1).addInput(in2).build().output(0);
        }
    }
}

 

虽然有了例子,可是公司的要求更高一些:要调用用于目标跟踪的 SSD 模型。这方面没找到例子,大家也搞不明白该怎么调用,于是就交给我来魔改了。下面的代码基本是从 Android 版 tensorflow.jar 里复制过来的。

TensorFlowInferenceInterface.java:类名是不是很眼熟?我基本上就是复制粘贴。

 

/**
 * Convenience wrapper around a TensorFlow {@link Graph} and {@link Session} that lets callers
 * feed inputs from Java arrays/buffers, run the graph, and copy outputs back into Java
 * arrays/buffers.
 *
 * <p>Typical use: one or more {@code feed(...)} calls, then {@link #run(String[])}, then one
 * {@code fetch(...)} call per output, and finally {@link #close()}.
 *
 * <p>Tensor names may carry an output index suffix ({@code "node:1"}); without a suffix
 * output 0 is used.
 */
public class TensorFlowInferenceInterface {
    /** Chunk size used when draining a model stream. */
    private static final int READ_BUFFER_SIZE = 16384;

    private final String modelName;
    private final Graph g;
    private final Session sess;
    private Session.Runner runner;
    private final List<String> feedNames = new ArrayList<>();
    private final List<Tensor> feedTensors = new ArrayList<>();
    private final List<String> fetchNames = new ArrayList<>();
    private List<Tensor> fetchTensors = new ArrayList<>();

    /**
     * Loads a serialized graph from a file.
     *
     * @param modelPath path of the serialized graph file
     * @throws RuntimeException if the file cannot be read or is not a valid graph
     */
    public TensorFlowInferenceInterface(String modelPath) {
        this.modelName = modelPath;
        this.g = new Graph();
        this.sess = new Session(this.g);
        this.runner = this.sess.runner();

        // try-with-resources closes the file on every path; the read loop does not rely on
        // available() or on a single read() call returning the whole file.
        try (InputStream in = new FileInputStream(modelPath)) {
            this.loadGraph(readFully(in), this.g);
        } catch (IOException e) {
            throw new RuntimeException("Failed to load model from '" + modelPath + "'", e);
        }
    }

    /**
     * Loads a serialized graph from a stream. The stream is read to its end but not closed.
     *
     * @param is stream positioned at the start of the serialized graph
     * @throws RuntimeException if the stream cannot be read or is not a valid graph
     */
    public TensorFlowInferenceInterface(InputStream is) {
        this.modelName = "";
        this.g = new Graph();
        this.sess = new Session(this.g);
        this.runner = this.sess.runner();

        try {
            this.loadGraph(readFully(is), this.g);
        } catch (IOException e) {
            throw new RuntimeException("Failed to load model from the input stream", e);
        }
    }

    /**
     * Wraps an already constructed graph.
     *
     * @param graph graph to run; it is closed by {@link #close()}
     */
    public TensorFlowInferenceInterface(Graph graph) {
        this.modelName = "";
        this.g = graph;
        this.sess = new Session(graph);
        this.runner = this.sess.runner();
    }

    /**
     * Runs the graph with the tensors fed so far and keeps the requested outputs for
     * subsequent {@code fetch(...)} calls. Outputs of the previous run are closed first.
     * Fed tensors are closed and a fresh runner is created whether or not the run succeeds.
     *
     * @param outputNames names of the outputs to compute
     */
    public void run(String[] outputNames) {
        this.closeFetches();

        try {
            for (String outputName : outputNames) {
                this.fetchNames.add(outputName);
                TensorId id = TensorId.parse(outputName);
                this.runner.fetch(id.name, id.outputIndex);
            }
            this.fetchTensors = this.runner.run();
        } catch (RuntimeException e) {
            // Nothing was produced, so forget the names; otherwise a later fetch() would
            // look up a tensor index that does not exist.
            this.fetchNames.clear();
            throw e;
        } finally {
            this.closeFeeds();
            this.runner = this.sess.runner();
        }
    }

    /** Returns the underlying graph. */
    public Graph graph() {
        return this.g;
    }

    /**
     * Looks up an operation by name.
     *
     * @param operationName name of the node
     * @return the operation
     * @throws RuntimeException if the graph has no such node
     */
    public Operation graphOperation(String operationName) {
        Operation operation = this.g.operation(operationName);
        if (operation == null) {
            throw new RuntimeException(
                    "Node '" + operationName + "' does not exist in model '" + this.modelName + "'");
        }
        return operation;
    }

    /** Closes all pending feed/fetch tensors, the session and the graph. */
    public void close() {
        this.closeFeeds();
        this.closeFetches();
        this.sess.close();
        this.g.close();
    }

    /** Safety net that closes this object when it is garbage collected. */
    @Override
    protected void finalize() throws Throwable {
        try {
            this.close();
        } finally {
            super.finalize();
        }
    }

    // ---- feed: register an input tensor built from the given data and dimensions ----

    public void feed(String inputName, float[] src, long... dims) {
        this.addFeed(inputName, Tensor.create(dims, FloatBuffer.wrap(src)));
    }

    public void feed(String inputName, int[] src, long... dims) {
        this.addFeed(inputName, Tensor.create(dims, IntBuffer.wrap(src)));
    }

    public void feed(String inputName, long[] src, long... dims) {
        this.addFeed(inputName, Tensor.create(dims, LongBuffer.wrap(src)));
    }

    public void feed(String inputName, double[] src, long... dims) {
        this.addFeed(inputName, Tensor.create(dims, DoubleBuffer.wrap(src)));
    }

    /** Feeds the bytes as a {@code UINT8} tensor. */
    public void feed(String inputName, byte[] src, long... dims) {
        this.addFeed(inputName, Tensor.create(DataType.UINT8, dims, ByteBuffer.wrap(src)));
    }

    /** Feeds a tensor created directly from the byte array via {@code Tensor.create(Object)}. */
    public void feedString(String inputName, byte[] src) {
        this.addFeed(inputName, Tensor.create(src));
    }

    /** Feeds a tensor created directly from the array of byte arrays via {@code Tensor.create(Object)}. */
    public void feedString(String inputName, byte[][] src) {
        this.addFeed(inputName, Tensor.create(src));
    }

    public void feed(String inputName, FloatBuffer src, long... dims) {
        this.addFeed(inputName, Tensor.create(dims, src));
    }

    public void feed(String inputName, IntBuffer src, long... dims) {
        this.addFeed(inputName, Tensor.create(dims, src));
    }

    public void feed(String inputName, LongBuffer src, long... dims) {
        this.addFeed(inputName, Tensor.create(dims, src));
    }

    public void feed(String inputName, DoubleBuffer src, long... dims) {
        this.addFeed(inputName, Tensor.create(dims, src));
    }

    /** Feeds the buffer contents as a {@code UINT8} tensor. */
    public void feed(String inputName, ByteBuffer src, long... dims) {
        this.addFeed(inputName, Tensor.create(DataType.UINT8, dims, src));
    }

    // ---- fetch: copy an output of the last run() into the given destination ----

    public void fetch(String outputName, float[] dst) {
        this.fetch(outputName, FloatBuffer.wrap(dst));
    }

    public void fetch(String outputName, int[] dst) {
        this.fetch(outputName, IntBuffer.wrap(dst));
    }

    public void fetch(String outputName, long[] dst) {
        this.fetch(outputName, LongBuffer.wrap(dst));
    }

    public void fetch(String outputName, double[] dst) {
        this.fetch(outputName, DoubleBuffer.wrap(dst));
    }

    public void fetch(String outputName, byte[] dst) {
        this.fetch(outputName, ByteBuffer.wrap(dst));
    }

    public void fetch(String outputName, FloatBuffer dst) {
        this.getTensor(outputName).writeTo(dst);
    }

    public void fetch(String outputName, IntBuffer dst) {
        this.getTensor(outputName).writeTo(dst);
    }

    public void fetch(String outputName, LongBuffer dst) {
        this.getTensor(outputName).writeTo(dst);
    }

    public void fetch(String outputName, DoubleBuffer dst) {
        this.getTensor(outputName).writeTo(dst);
    }

    public void fetch(String outputName, ByteBuffer dst) {
        this.getTensor(outputName).writeTo(dst);
    }

    /** Reads the stream until end-of-stream and returns everything that was read. */
    private static byte[] readFully(InputStream in) throws IOException {
        // available() is only a sizing hint here; correctness comes from the read loop.
        ByteArrayOutputStream out =
                new ByteArrayOutputStream(Math.max(in.available(), READ_BUFFER_SIZE));
        byte[] chunk = new byte[READ_BUFFER_SIZE];

        int count;
        while ((count = in.read(chunk, 0, chunk.length)) != -1) {
            out.write(chunk, 0, count);
        }
        return out.toByteArray();
    }

    /**
     * Imports a serialized graph definition into {@code graph}.
     *
     * @throws IOException if the bytes are rejected by the graph importer
     */
    private void loadGraph(byte[] graphDef, Graph graph) throws IOException {
        try {
            graph.importGraphDef(graphDef);
        } catch (IllegalArgumentException e) {
            // Keep the original exception as the cause so the stack trace is not lost.
            throw new IOException(
                    "Not a valid TensorFlow Graph serialization: " + e.getMessage(), e);
        }
    }

    /** Registers the tensor with the current runner and remembers it so it can be closed. */
    private void addFeed(String inputName, Tensor tensor) {
        TensorId id = TensorId.parse(inputName);
        this.runner.feed(id.name, id.outputIndex, tensor);
        this.feedNames.add(inputName);
        this.feedTensors.add(tensor);
    }

    /**
     * Returns the tensor produced for {@code outputName} by the last run.
     *
     * @throws RuntimeException if the name was not passed to {@link #run(String[])}
     */
    private Tensor getTensor(String outputName) {
        int index = this.fetchNames.indexOf(outputName);
        if (index < 0) {
            throw new RuntimeException(
                    "Node '" + outputName + "' was not provided to run(), so it cannot be read");
        }
        return this.fetchTensors.get(index);
    }

    /** Closes and forgets every tensor that was fed. */
    private void closeFeeds() {
        for (Tensor tensor : this.feedTensors) {
            tensor.close();
        }
        this.feedTensors.clear();
        this.feedNames.clear();
    }

    /** Closes and forgets every tensor produced by the last run. */
    private void closeFetches() {
        for (Tensor tensor : this.fetchTensors) {
            tensor.close();
        }
        this.fetchTensors.clear();
        this.fetchNames.clear();
    }

    /** A node name plus the index of one of its outputs. */
    private static class TensorId {
        String name;
        int outputIndex;

        private TensorId() {
        }

        /**
         * Splits {@code "name:index"} at the last colon. When there is no colon, or the text
         * after it is not an integer, the whole string is the name and the index is 0.
         */
        public static TensorFlowInferenceInterface.TensorId parse(String text) {
            TensorFlowInferenceInterface.TensorId id = new TensorFlowInferenceInterface.TensorId();
            int colon = text.lastIndexOf(':');
            if (colon < 0) {
                id.outputIndex = 0;
                id.name = text;
                return id;
            }

            try {
                id.outputIndex = Integer.parseInt(text.substring(colon + 1));
                id.name = text.substring(0, colon);
            } catch (NumberFormatException e) {
                // The suffix is not a number, so the colon belongs to the node name.
                id.outputIndex = 0;
                id.name = text;
            }
            return id;
        }
    }
}

 

Classifier.java

 

/**
 * Contract for an image recognizer that returns a list of {@link Recognition} results.
 */
public interface Classifier {

    /**
     * One recognition result: an identifier, a display title, a confidence value and the
     * four edges of a bounding box. Identifier, title and confidence are fixed at
     * construction; the box edges can be changed through setters.
     */
    public class Recognition {

        /** Identifier of the result; may be null. */
        private final String id;

        /** Display name of the result; may be null. */
        private final String title;

        /** Confidence value; may be null. Printed by {@link #toString()} multiplied by 100. */
        private final Float confidence;

        /** Bounding-box edges. */
        private float left;
        private float top;
        private float right;
        private float bottom;

        public Recognition(
                final String id,
                final String title,
                final Float confidence,
                float left,
                float top,
                float right,
                float bottom) {
            this.id = id;
            this.title = title;
            this.confidence = confidence;
            this.left = left;
            this.top = top;
            this.right = right;
            this.bottom = bottom;
        }

        public String getId() {
            return this.id;
        }

        public String getTitle() {
            return this.title;
        }

        public Float getConfidence() {
            return this.confidence;
        }

        public float getLeft() {
            return this.left;
        }

        public void setLeft(float left) {
            this.left = left;
        }

        public float getTop() {
            return this.top;
        }

        public void setTop(float top) {
            this.top = top;
        }

        public float getRight() {
            return this.right;
        }

        public void setRight(float right) {
            this.right = right;
        }

        public float getBottom() {
            return this.bottom;
        }

        public void setBottom(float bottom) {
            this.bottom = bottom;
        }

        /**
         * Renders the non-null fields and the non-zero box edges, separated by spaces, in the
         * order id, title, confidence, left, top, right, bottom.
         */
        @Override
        public String toString() {
            final StringBuilder text = new StringBuilder();

            if (id != null) {
                text.append("[").append(id).append("] ");
            }
            if (title != null) {
                text.append(title).append(" ");
            }
            if (confidence != null) {
                text.append(String.format("(%.1f%%) ", confidence * 100.0f));
            }

            // Edges equal to zero are left out of the text.
            final float[] edges = {left, top, right, bottom};
            for (final float edge : edges) {
                if (edge != 0) {
                    text.append(edge).append(" ");
                }
            }

            return text.toString().trim();
        }
    }

    /**
     * Runs recognition on an image.
     *
     * @param byteValues the image data (NOTE(review): despite the name this is an
     *     {@code int[]}, presumably one packed pixel per element — confirm against the
     *     implementing class)
     * @return the recognition results
     */
    List<Recognition> recognizeImage(int[] byteValues);

    /** Releases the resources held by the classifier. */
    void close();
}

 

TensorFlowObjectDetectionAPIModel.java

 

/**
 * {@link Classifier} backed by a frozen object-detection graph that exposes an
 * {@code image_tensor} input and {@code detection_boxes}, {@code detection_scores},
 * {@code detection_classes} and {@code num_detections} outputs.
 *
 * <p>Instances are created with {@link #create(String, String, int)} and must be released
 * with {@link #close()}.
 */
public class TensorFlowObjectDetectionAPIModel implements Classifier {

    /** Size of the output buffers and upper bound on the number of returned results. */
    private static final int MAX_RESULTS = 100;

    private String inputName;
    private int inputSize;

    /** Label text per class index, one entry per line of the label file. */
    private final List<String> labels = new ArrayList<String>();

    /** Reusable RGB input buffer, three bytes per pixel. */
    private byte[] byteValues;
    private String[] outputNames;

    private TensorFlowInferenceInterface inferenceInterface;

    /**
     * Loads the label file and the model and checks that the expected nodes exist.
     *
     * @param modelFilename path of the serialized graph
     * @param labelFilename path of a text file with one label per line
     * @param inputSize width and height, in pixels, of the square input image
     * @return a ready-to-use classifier
     * @throws IOException if the label file cannot be read
     * @throws RuntimeException if the model cannot be loaded or a required node is missing
     */
    public static Classifier create(
            final String modelFilename,
            final String labelFilename,
            final int inputSize) throws IOException {
        final TensorFlowObjectDetectionAPIModel d = new TensorFlowObjectDetectionAPIModel();

        // try-with-resources closes the label file even when reading fails.
        try (BufferedReader br =
                     new BufferedReader(new InputStreamReader(new FileInputStream(labelFilename)))) {
            String line;
            while ((line = br.readLine()) != null) {
                d.labels.add(line);
            }
        }

        d.inferenceInterface = new TensorFlowInferenceInterface(modelFilename);
        try {
            final Graph g = d.inferenceInterface.graph();

            d.inputName = "image_tensor";
            requireOperation(g, d.inputName, "input");
            d.inputSize = inputSize;

            requireOperation(g, "detection_scores", "output");
            requireOperation(g, "detection_boxes", "output");
            requireOperation(g, "detection_classes", "output");
            // num_detections is fetched on every run as well, so verify it up front too.
            requireOperation(g, "num_detections", "output");

            d.outputNames = new String[] {"detection_boxes", "detection_scores",
                    "detection_classes", "num_detections"};
            d.byteValues = new byte[d.inputSize * d.inputSize * 3];
        } catch (RuntimeException e) {
            // Do not leave the session and graph open when the model turns out to be unusable.
            d.inferenceInterface.close();
            throw e;
        }
        return d;
    }

    private TensorFlowObjectDetectionAPIModel() {}

    /** Throws if the graph has no operation called {@code name}; {@code kind} is "input" or "output". */
    private static void requireOperation(final Graph g, final String name, final String kind) {
        if (g.operation(name) == null) {
            throw new RuntimeException("Failed to find " + kind + " Node '" + name + "'");
        }
    }

    /**
     * Runs detection on one image.
     *
     * @param intValues exactly {@code inputSize * inputSize} pixels; the low three bytes of
     *     each element are copied to the input buffer in the order bits 16-23, 8-15, 0-7
     * @return up to {@code MAX_RESULTS} results ordered by descending confidence; box edges
     *     are the model's box values multiplied by {@code inputSize}
     * @throws IllegalArgumentException if {@code intValues} has the wrong length
     */
    @Override
    public List<Recognition> recognizeImage(int[] intValues) {
        final int expectedPixels = inputSize * inputSize;
        if (intValues == null || intValues.length != expectedPixels) {
            throw new IllegalArgumentException(
                    "Expected " + expectedPixels + " pixels (" + inputSize + "x" + inputSize
                            + ") but got " + (intValues == null ? "null" : "" + intValues.length));
        }

        // Unpack each int pixel into three consecutive bytes of the input buffer.
        for (int i = 0; i < intValues.length; ++i) {
            byteValues[i * 3 + 2] = (byte) (intValues[i] & 0xFF);
            byteValues[i * 3 + 1] = (byte) ((intValues[i] >> 8) & 0xFF);
            byteValues[i * 3 + 0] = (byte) ((intValues[i] >> 16) & 0xFF);
        }

        inferenceInterface.feed(inputName, byteValues, 1, inputSize, inputSize, 3);
        inferenceInterface.run(outputNames);

        // Fresh, zero-filled buffers for every call so stale values never leak between runs.
        final float[] outputLocations = new float[MAX_RESULTS * 4];
        final float[] outputScores = new float[MAX_RESULTS];
        final float[] outputClasses = new float[MAX_RESULTS];
        final float[] outputNumDetections = new float[1];
        inferenceInterface.fetch(outputNames[0], outputLocations);
        inferenceInterface.fetch(outputNames[1], outputScores);
        inferenceInterface.fetch(outputNames[2], outputClasses);
        inferenceInterface.fetch(outputNames[3], outputNumDetections);

        // Highest confidence first.
        final PriorityQueue<Recognition> pq =
                new PriorityQueue<Recognition>(MAX_RESULTS, new Comparator<Recognition>() {
                    @Override
                    public int compare(final Recognition lhs, final Recognition rhs) {
                        return Float.compare(rhs.getConfidence(), lhs.getConfidence());
                    }
                });

        for (int i = 0; i < outputScores.length; ++i) {
            // Each box occupies four consecutive floats, read here as top, left, bottom, right.
            float left = outputLocations[4 * i + 1] * inputSize;
            float top = outputLocations[4 * i] * inputSize;
            float right = outputLocations[4 * i + 3] * inputSize;
            float bottom = outputLocations[4 * i + 2] * inputSize;

            pq.add(
                    new Recognition("" + i, labels.get((int) outputClasses[i]), outputScores[i], left, top, right, bottom));
        }

        // Drain the queue. The bound is the result count, not a loop index compared with the
        // shrinking queue size, which would stop after roughly half of the entries.
        final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
        while (!pq.isEmpty() && recognitions.size() < MAX_RESULTS) {
            recognitions.add(pq.poll());
        }
        return recognitions;
    }

    /** Closes the underlying inference interface. */
    @Override
    public void close() {
        inferenceInterface.close();
    }
}

 

DetectionImage.java:这个是调用上面几个类的实现类,使用时只需要复制其中的实现代码即可。

 

/**
 * Command-line sample: loads the detection model, scales one image to the model's input
 * size, runs detection and prints every result.
 */
public class DetectionImage {

    /**
     * Entry point; the model, label and image paths are hard-coded.
     *
     * @param args unused
     */
    public static void main(String[] args) {

        int input_size = 300;

        Classifier d = null;

        try {
            d = TensorFlowObjectDetectionAPIModel.create("C:\\sts\\ssd_mobilenet_v1_android_export.pb",
                    "C:\\sts\\coco_labels_list.txt", input_size);
        } catch (IOException e) {
            e.printStackTrace();
        }
        if (d == null) {
            return;
        }

        try {
            File file = new File("C:\\sts\\person.jpg");
            Image img = null;
            try {
                img = ImageIO.read(file);
            } catch (IOException e) {
                e.printStackTrace();
            }
            if (img == null) {
                return;
            }

            int width = img.getWidth(null);
            int height = img.getHeight(null);

            // Stretch the whole source image onto an input_size x input_size RGB canvas.
            BufferedImage image = new BufferedImage(input_size, input_size, BufferedImage.TYPE_INT_RGB);
            java.awt.Graphics graphics = image.getGraphics();
            try {
                graphics.drawImage(img, 0, 0, input_size, input_size, 0, 0, width, height, null);
            } finally {
                // Graphics contexts hold system resources and should be released explicitly.
                graphics.dispose();
            }

            int[] rgbs = new int[input_size * input_size];
            image.getRGB(0, 0, input_size, input_size, rgbs, 0, input_size);

            List<Recognition> results = d.recognizeImage(rgbs);

            for (Recognition result : results) {
                System.out.println(result.toString());
            }
        } finally {
            // Release the classifier on every path, including early returns and exceptions.
            d.close();
        }
    }
}

 

输出结果

[0] person (99.1%) 113.1657 27.434679 183.88785 296.03595

没想到java用起来还是蛮6的,果然android不行了还可以转行java。。。还是程序猿。。。

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值