公司想做个web端的识别功能,可网上例子很少,官网可以找到例子:java代码
不过直接拿来用可能会出问题:我们web端用的是TensorFlow 1.3.0,可是官网的例子已经是1.4.0了,代码又不一样了……(为什么要说“又”呢)
可以通过下载老版本的源码找到例子代码
还是贴一下1.3.0代码吧
/**
 * Sample program that classifies one JPEG image with a frozen TensorFlow graph and prints the
 * best-scoring label. The model directory and image path are hard-coded in {@link #main}.
 */
public class LabelImage {
  /**
   * Reads the graph definition, the label list and the image, runs inference and prints the
   * label with the highest probability.
   *
   * @param args unused; all paths are hard-coded below
   */
  public static void main(String[] args) {
    String modelDir = "C:\\sts";
    String imageFile = "C:\\sts\\timg.jpg";
    byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "tensorflow_inception_graph.pb"));
    List<String> labels =
        readAllLinesOrExit(Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt"));
    byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));
    // The normalized image tensor is closed as soon as inference is done.
    try (Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
      float[] labelProbabilities = executeInceptionGraph(graphDef, image);
      int bestLabelIdx = maxIndex(labelProbabilities);
      System.out.println(
          String.format(
              "BEST MATCH: %s (%.2f%% likely)",
              labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
    }
  }

  /**
   * Builds and runs a small graph that turns raw JPEG bytes into the model input: decode as a
   * 3-channel image, cast to float, add a batch dimension at index 0, resize bilinearly to
   * 224x224, subtract the mean (117) and divide by the scale (1).
   *
   * @param imageBytes the encoded JPEG file contents
   * @return the preprocessed image tensor; the caller is responsible for closing it
   */
  private static Tensor constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
    try (Graph g = new Graph()) {
      GraphBuilder b = new GraphBuilder(g);
      // NOTE(review): these constants presumably match the preprocessing the bundled model
      // was trained with -- confirm before swapping in a different model.
      final int H = 224;
      final int W = 224;
      final float mean = 117f;
      final float scale = 1f;
      final Output input = b.constant("input", imageBytes);
      final Output output =
          b.div(
              b.sub(
                  b.resizeBilinear(
                      b.expandDims(
                          b.cast(b.decodeJpeg(input, 3), DataType.FLOAT),
                          b.constant("make_batch", 0)),
                      b.constant("size", new int[] {H, W})),
                  b.constant("mean", mean)),
              b.constant("scale", scale));
      try (Session s = new Session(g)) {
        return s.runner().fetch(output.op().name()).run().get(0);
      }
    }
  }

  /**
   * Imports the serialized graph, feeds {@code image} to the node named "input" and fetches the
   * node named "output".
   *
   * @param graphDef serialized graph definition
   * @param image preprocessed image tensor
   * @return the per-label values of the single result row
   * @throws RuntimeException if the result is not a two-dimensional tensor whose first
   *     dimension is 1
   */
  private static float[] executeInceptionGraph(byte[] graphDef, Tensor image) {
    try (Graph g = new Graph()) {
      g.importGraphDef(graphDef);
      try (Session s = new Session(g);
          Tensor result = s.runner().feed("input", image).fetch("output").run().get(0)) {
        final long[] rshape = result.shape();
        if (result.numDimensions() != 2 || rshape[0] != 1) {
          throw new RuntimeException(
              String.format(
                  "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                  Arrays.toString(rshape)));
        }
        int nlabels = (int) rshape[1];
        return result.copyTo(new float[1][nlabels])[0];
      }
    }
  }

  /**
   * Returns the index of the largest element; the first one wins on ties.
   *
   * @param probabilities values to scan
   * @return index of the maximum, or 0 when the array is empty
   */
  private static int maxIndex(float[] probabilities) {
    int best = 0;
    for (int i = 1; i < probabilities.length; ++i) {
      if (probabilities[i] > probabilities[best]) {
        best = i;
      }
    }
    return best;
  }

  /**
   * Reads the whole file, or reports the failure on stderr and terminates the JVM with exit
   * status 1.
   *
   * @param path file to read
   * @return the file contents
   */
  private static byte[] readAllBytesOrExit(Path path) {
    try {
      return Files.readAllBytes(path);
    } catch (IOException e) {
      System.err.println("Failed to read [" + path + "]: " + e.getMessage());
      System.exit(1);
    }
    // Only present to satisfy the compiler; System.exit does not return normally.
    return null;
  }

  /**
   * Reads all lines of a UTF-8 text file, or reports the failure on stderr and terminates the
   * JVM with exit status 1.
   *
   * @param path file to read
   * @return the lines of the file
   */
  private static List<String> readAllLinesOrExit(Path path) {
    try {
      return Files.readAllLines(path, Charset.forName("UTF-8"));
    } catch (IOException e) {
      System.err.println("Failed to read [" + path + "]: " + e.getMessage());
      // A read failure is an error, so signal it with a non-zero status exactly like
      // readAllBytesOrExit does (this used to exit with 0, i.e. "success").
      System.exit(1);
    }
    // Only present to satisfy the compiler; System.exit does not return normally.
    return null;
  }

  /**
   * Small convenience layer for adding the handful of ops this sample needs to a graph. Every
   * op except constants is registered under a node name equal to its op type.
   */
  static class GraphBuilder {
    private final Graph g;

    GraphBuilder(Graph g) {
      this.g = g;
    }

    /** Adds a "Div" op computing {@code x / y}. */
    Output div(Output x, Output y) {
      return binaryOp("Div", x, y);
    }

    /** Adds a "Sub" op computing {@code x - y}. */
    Output sub(Output x, Output y) {
      return binaryOp("Sub", x, y);
    }

    /** Adds a "ResizeBilinear" op with inputs {@code images} and {@code size}. */
    Output resizeBilinear(Output images, Output size) {
      return binaryOp("ResizeBilinear", images, size);
    }

    /** Adds an "ExpandDims" op with inputs {@code input} and {@code dim}. */
    Output expandDims(Output input, Output dim) {
      return binaryOp("ExpandDims", input, dim);
    }

    /** Adds a "Cast" op whose "DstT" attribute is {@code dtype}. */
    Output cast(Output value, DataType dtype) {
      return g.opBuilder("Cast", "Cast").addInput(value).setAttr("DstT", dtype).build().output(0);
    }

    /** Adds a "DecodeJpeg" op whose "channels" attribute is {@code channels}. */
    Output decodeJpeg(Output contents, long channels) {
      return g.opBuilder("DecodeJpeg", "DecodeJpeg")
          .addInput(contents)
          .setAttr("channels", channels)
          .build()
          .output(0);
    }

    /**
     * Adds a "Const" node named {@code name} holding {@code value}. The temporary tensor used
     * to carry the value is closed once the node has been built.
     */
    Output constant(String name, Object value) {
      try (Tensor t = Tensor.create(value)) {
        return g.opBuilder("Const", name)
            .setAttr("dtype", t.dataType())
            .setAttr("value", t)
            .build()
            .output(0);
      }
    }

    /** Adds a two-input op of the given type, using the type as the node name. */
    private Output binaryOp(String type, Output in1, Output in2) {
      return g.opBuilder(type, type).addInput(in1).addInput(in2).build().output(0);
    }
  }
}
虽然有例子,可是公司的要求更高一些,要实现跟踪的SSD模型。这个没找到例子,也搞不明白怎么调用,于是就交给我来魔改了。下面的代码完全复制自android的tensorflow.jar里的代码。
TensorFlowInferenceInterface.java(类名是不是很熟?我基本就是复制粘贴):
/**
 * Wrapper around a TensorFlow {@link Graph} and {@link Session} that collects input tensors
 * with {@code feed}, executes the graph with {@link #run(String[])} and copies results out with
 * {@code fetch}.
 *
 * <p>Tensor names accepted by {@code feed}, {@code run} and {@code fetch} are either a plain
 * node name (output index 0) or {@code "name:index"}.
 */
public class TensorFlowInferenceInterface {
  private final String modelName;
  private final Graph g;
  private final Session sess;
  private Session.Runner runner;
  private List<String> feedNames = new ArrayList<>();
  private List<Tensor> feedTensors = new ArrayList<>();
  private List<String> fetchNames = new ArrayList<>();
  private List<Tensor> fetchTensors = new ArrayList<>();

  /**
   * Loads a serialized graph from a file.
   *
   * @param var2 path of the model file
   * @throws RuntimeException if the file cannot be read or is not a valid graph
   */
  public TensorFlowInferenceInterface(String var2) {
    this.modelName = var2;
    this.g = new Graph();
    this.sess = new Session(this.g);
    this.runner = this.sess.runner();
    // try-with-resources closes the file on every path; the read loop does not depend on
    // available() or on a single read() call returning the whole file.
    try (InputStream in = new FileInputStream(var2)) {
      this.loadGraph(readFully(in), this.g);
    } catch (IOException e) {
      this.releaseSessionAndGraph();
      throw new RuntimeException("Failed to load model from '" + var2 + "'", e);
    }
  }

  /**
   * Loads a serialized graph from a stream. The stream is read to the end but not closed.
   *
   * @param var1 stream supplying the serialized graph
   * @throws RuntimeException if the stream cannot be read or is not a valid graph
   */
  public TensorFlowInferenceInterface(InputStream var1) {
    this.modelName = "";
    this.g = new Graph();
    this.sess = new Session(this.g);
    this.runner = this.sess.runner();
    try {
      this.loadGraph(readFully(var1), this.g);
    } catch (IOException e) {
      this.releaseSessionAndGraph();
      throw new RuntimeException("Failed to load model from the input stream", e);
    }
  }

  /**
   * Wraps an already constructed graph. Note that {@link #close()} closes this graph too.
   *
   * @param var1 graph to run
   */
  public TensorFlowInferenceInterface(Graph var1) {
    this.modelName = "";
    this.g = var1;
    this.sess = new Session(var1);
    this.runner = this.sess.runner();
  }

  /**
   * Executes the graph with the tensors fed so far and keeps the requested outputs for later
   * {@code fetch} calls. Results of any previous run are closed first. Whether or not the run
   * succeeds, the fed tensors are closed and a fresh runner is prepared for the next call.
   *
   * @param var1 names of the outputs to compute
   */
  public void run(String[] var1) {
    this.closeFetches();
    boolean succeeded = false;
    try {
      for (String fetchName : var1) {
        this.fetchNames.add(fetchName);
        TensorId id = TensorId.parse(fetchName);
        this.runner.fetch(id.name, id.outputIndex);
      }
      this.fetchTensors = this.runner.run();
      succeeded = true;
    } finally {
      if (!succeeded) {
        // No result tensors exist, so forget the names; a later fetch() then reports that
        // the node was not provided to run() instead of indexing into an empty list.
        this.fetchNames.clear();
      }
      this.closeFeeds();
      this.runner = this.sess.runner();
    }
  }

  /** Returns the underlying graph. */
  public Graph graph() {
    return this.g;
  }

  /**
   * Looks up an operation by name.
   *
   * @param var1 node name
   * @return the operation
   * @throws RuntimeException if the graph has no node with that name
   */
  public Operation graphOperation(String var1) {
    Operation op = this.g.operation(var1);
    if (op == null) {
      throw new RuntimeException(
          "Node '" + var1 + "' does not exist in model '" + this.modelName + "'");
    }
    return op;
  }

  /** Closes pending feed tensors, fetched tensors, the session and the graph. */
  public void close() {
    this.closeFeeds();
    this.closeFetches();
    this.sess.close();
    this.g.close();
  }

  /** Safety net that calls {@link #close()} if the owner never did. */
  @Override
  protected void finalize() throws Throwable {
    try {
      this.close();
    } finally {
      super.finalize();
    }
  }

  /**
   * Feeds float data to the tensor named {@code var1}.
   *
   * @param var1 tensor name
   * @param var2 data, wrapped in a {@link FloatBuffer}
   * @param var3 passed to {@code Tensor.create} as the tensor shape
   */
  public void feed(String var1, float[] var2, long... var3) {
    this.addFeed(var1, Tensor.create(var3, FloatBuffer.wrap(var2)));
  }

  /** Same as {@link #feed(String, float[], long...)} for int data. */
  public void feed(String var1, int[] var2, long... var3) {
    this.addFeed(var1, Tensor.create(var3, IntBuffer.wrap(var2)));
  }

  /** Same as {@link #feed(String, float[], long...)} for long data. */
  public void feed(String var1, long[] var2, long... var3) {
    this.addFeed(var1, Tensor.create(var3, LongBuffer.wrap(var2)));
  }

  /** Same as {@link #feed(String, float[], long...)} for double data. */
  public void feed(String var1, double[] var2, long... var3) {
    this.addFeed(var1, Tensor.create(var3, DoubleBuffer.wrap(var2)));
  }

  /** Same as {@link #feed(String, float[], long...)}; the bytes become a UINT8 tensor. */
  public void feed(String var1, byte[] var2, long... var3) {
    this.addFeed(var1, Tensor.create(DataType.UINT8, var3, ByteBuffer.wrap(var2)));
  }

  /**
   * Feeds the tensor that {@code Tensor.create} builds from a single byte array
   * (presumably a scalar string tensor -- confirm against the TensorFlow API).
   */
  public void feedString(String var1, byte[] var2) {
    this.addFeed(var1, Tensor.create(var2));
  }

  /**
   * Feeds the tensor that {@code Tensor.create} builds from an array of byte arrays
   * (presumably a vector of strings -- confirm against the TensorFlow API).
   */
  public void feedString(String var1, byte[][] var2) {
    this.addFeed(var1, Tensor.create(var2));
  }

  /**
   * Feeds float data taken from a buffer to the tensor named {@code var1}.
   *
   * @param var1 tensor name
   * @param var2 data buffer
   * @param var3 passed to {@code Tensor.create} as the tensor shape
   */
  public void feed(String var1, FloatBuffer var2, long... var3) {
    this.addFeed(var1, Tensor.create(var3, var2));
  }

  /** Same as {@link #feed(String, FloatBuffer, long...)} for int data. */
  public void feed(String var1, IntBuffer var2, long... var3) {
    this.addFeed(var1, Tensor.create(var3, var2));
  }

  /** Same as {@link #feed(String, FloatBuffer, long...)} for long data. */
  public void feed(String var1, LongBuffer var2, long... var3) {
    this.addFeed(var1, Tensor.create(var3, var2));
  }

  /** Same as {@link #feed(String, FloatBuffer, long...)} for double data. */
  public void feed(String var1, DoubleBuffer var2, long... var3) {
    this.addFeed(var1, Tensor.create(var3, var2));
  }

  /** Same as {@link #feed(String, FloatBuffer, long...)}; the bytes become a UINT8 tensor. */
  public void feed(String var1, ByteBuffer var2, long... var3) {
    this.addFeed(var1, Tensor.create(DataType.UINT8, var3, var2));
  }

  /**
   * Copies the output named {@code var1} of the last {@link #run(String[])} into {@code var2}.
   *
   * @param var1 tensor name, spelled exactly as it was passed to {@code run}
   * @param var2 destination array
   * @throws RuntimeException if {@code var1} was not among the names given to {@code run}
   */
  public void fetch(String var1, float[] var2) {
    this.fetch(var1, FloatBuffer.wrap(var2));
  }

  /** Same as {@link #fetch(String, float[])} for int data. */
  public void fetch(String var1, int[] var2) {
    this.fetch(var1, IntBuffer.wrap(var2));
  }

  /** Same as {@link #fetch(String, float[])} for long data. */
  public void fetch(String var1, long[] var2) {
    this.fetch(var1, LongBuffer.wrap(var2));
  }

  /** Same as {@link #fetch(String, float[])} for double data. */
  public void fetch(String var1, double[] var2) {
    this.fetch(var1, DoubleBuffer.wrap(var2));
  }

  /** Same as {@link #fetch(String, float[])} for byte data. */
  public void fetch(String var1, byte[] var2) {
    this.fetch(var1, ByteBuffer.wrap(var2));
  }

  /**
   * Writes the output named {@code var1} of the last {@link #run(String[])} into a buffer.
   *
   * @param var1 tensor name, spelled exactly as it was passed to {@code run}
   * @param var2 destination buffer
   * @throws RuntimeException if {@code var1} was not among the names given to {@code run}
   */
  public void fetch(String var1, FloatBuffer var2) {
    this.getTensor(var1).writeTo(var2);
  }

  /** Same as {@link #fetch(String, FloatBuffer)} for int data. */
  public void fetch(String var1, IntBuffer var2) {
    this.getTensor(var1).writeTo(var2);
  }

  /** Same as {@link #fetch(String, FloatBuffer)} for long data. */
  public void fetch(String var1, LongBuffer var2) {
    this.getTensor(var1).writeTo(var2);
  }

  /** Same as {@link #fetch(String, FloatBuffer)} for double data. */
  public void fetch(String var1, DoubleBuffer var2) {
    this.getTensor(var1).writeTo(var2);
  }

  /** Same as {@link #fetch(String, FloatBuffer)} for byte data. */
  public void fetch(String var1, ByteBuffer var2) {
    this.getTensor(var1).writeTo(var2);
  }

  /** Reads the stream until end-of-file in 16 KiB chunks and returns everything read. */
  private static byte[] readFully(InputStream in) throws IOException {
    int initialCapacity = Math.max(in.available(), 16384);
    ByteArrayOutputStream collected = new ByteArrayOutputStream(initialCapacity);
    byte[] chunk = new byte[16384];
    int count;
    while ((count = in.read(chunk, 0, chunk.length)) != -1) {
      collected.write(chunk, 0, count);
    }
    return collected.toByteArray();
  }

  /** Closes the session and graph created by a constructor whose model load failed. */
  private void releaseSessionAndGraph() {
    this.sess.close();
    this.g.close();
  }

  /**
   * Imports the serialized graph definition into {@code var2}.
   *
   * @throws IOException if the bytes are rejected as a graph definition; the original
   *     exception is kept as the cause
   */
  private void loadGraph(byte[] var1, Graph var2) throws IOException {
    try {
      var2.importGraphDef(var1);
    } catch (IllegalArgumentException e) {
      throw new IOException(
          "Not a valid TensorFlow Graph serialization: " + e.getMessage(), e);
    }
  }

  /**
   * Registers {@code var2} as the value of the tensor named {@code var1} for the next run and
   * takes ownership of it. If the runner rejects the feed, the tensor is closed immediately
   * because nothing else would ever close it.
   */
  private void addFeed(String var1, Tensor var2) {
    try {
      TensorId id = TensorId.parse(var1);
      this.runner.feed(id.name, id.outputIndex, var2);
    } catch (RuntimeException e) {
      var2.close();
      throw e;
    }
    this.feedNames.add(var1);
    this.feedTensors.add(var2);
  }

  /** Returns the result tensor at the position of the first fetch name equal to {@code var1}. */
  private Tensor getTensor(String var1) {
    int position = this.fetchNames.indexOf(var1);
    if (position < 0) {
      throw new RuntimeException(
          "Node '" + var1 + "' was not provided to run(), so it cannot be read");
    }
    return this.fetchTensors.get(position);
  }

  /** Closes and forgets every tensor fed since the last run. */
  private void closeFeeds() {
    for (Tensor fed : this.feedTensors) {
      fed.close();
    }
    this.feedTensors.clear();
    this.feedNames.clear();
  }

  /** Closes and forgets the results of the last run. */
  private void closeFetches() {
    for (Tensor fetched : this.fetchTensors) {
      fetched.close();
    }
    this.fetchTensors.clear();
    this.fetchNames.clear();
  }

  /** A node name plus the index of one of that node's outputs. */
  private static class TensorId {
    String name;
    int outputIndex;

    private TensorId() {
    }

    /**
     * Splits {@code "name:index"} at the last colon. When there is no colon, or the text after
     * it is not an integer, the whole string is the name and the index is 0.
     */
    public static TensorFlowInferenceInterface.TensorId parse(String var0) {
      TensorFlowInferenceInterface.TensorId id = new TensorFlowInferenceInterface.TensorId();
      int colon = var0.lastIndexOf(':');
      if (colon < 0) {
        id.outputIndex = 0;
        id.name = var0;
        return id;
      }
      try {
        id.outputIndex = Integer.parseInt(var0.substring(colon + 1));
        id.name = var0.substring(0, colon);
      } catch (NumberFormatException e) {
        id.outputIndex = 0;
        id.name = var0;
      }
      return id;
    }
  }
}
Classifier.java
/**
 * Contract for components that recognize objects in an image and report them as
 * {@link Recognition} results.
 */
public interface Classifier {
  /**
   * One recognition result: an identifier, a display title, a confidence value and the four
   * edges of the bounding box. Identifier, title and confidence are fixed at construction;
   * the box edges can be changed afterwards.
   */
  public class Recognition {
    private final String id;
    private final String title;
    private final Float confidence;
    private float left;
    private float top;
    private float right;
    private float bottom;

    /**
     * Creates a result.
     *
     * @param id identifier of the result, may be null
     * @param title display name of the recognized object, may be null
     * @param confidence confidence value, may be null
     * @param left left edge of the box
     * @param top top edge of the box
     * @param right right edge of the box
     * @param bottom bottom edge of the box
     */
    public Recognition(
        final String id,
        final String title,
        final Float confidence,
        float left,
        float top,
        float right,
        float bottom) {
      this.id = id;
      this.title = title;
      this.confidence = confidence;
      this.left = left;
      this.top = top;
      this.right = right;
      this.bottom = bottom;
    }

    /** Returns the identifier given at construction. */
    public String getId() {
      return this.id;
    }

    /** Returns the title given at construction. */
    public String getTitle() {
      return this.title;
    }

    /** Returns the confidence given at construction. */
    public Float getConfidence() {
      return this.confidence;
    }

    /** Returns the left edge of the box. */
    public float getLeft() {
      return this.left;
    }

    /** Replaces the left edge of the box. */
    public void setLeft(float left) {
      this.left = left;
    }

    /** Returns the top edge of the box. */
    public float getTop() {
      return this.top;
    }

    /** Replaces the top edge of the box. */
    public void setTop(float top) {
      this.top = top;
    }

    /** Returns the right edge of the box. */
    public float getRight() {
      return this.right;
    }

    /** Replaces the right edge of the box. */
    public void setRight(float right) {
      this.right = right;
    }

    /** Returns the bottom edge of the box. */
    public float getBottom() {
      return this.bottom;
    }

    /** Replaces the bottom edge of the box. */
    public void setBottom(float bottom) {
      this.bottom = bottom;
    }

    /**
     * Renders the result as {@code "[id] title (NN.N%) left top right bottom"}. Parts that are
     * null, and box edges equal to zero, are left out; surrounding whitespace is trimmed.
     */
    @Override
    public String toString() {
      final StringBuilder text = new StringBuilder();
      if (id != null) {
        text.append('[').append(id).append("] ");
      }
      if (title != null) {
        text.append(title).append(' ');
      }
      if (confidence != null) {
        text.append(String.format("(%.1f%%) ", confidence * 100.0f));
      }
      // Edges are emitted in the fixed order left, top, right, bottom.
      for (final float edge : new float[] {left, top, right, bottom}) {
        if (edge != 0) {
          text.append(edge).append(' ');
        }
      }
      return text.toString().trim();
    }
  }

  /**
   * Runs recognition on one image.
   *
   * @param byteValues the image data as an int array
   * @return the recognized objects
   */
  List<Recognition> recognizeImage(int[] byteValues);

  /** Releases whatever resources the implementation holds. */
  void close();
}
TensorFlowObjectDetectionAPIModel.java
/**
 * {@link Classifier} that runs a frozen detection graph through
 * {@link TensorFlowInferenceInterface}. The graph must expose the input node "image_tensor" and
 * the output nodes "detection_boxes", "detection_scores", "detection_classes" and
 * "num_detections".
 */
public class TensorFlowObjectDetectionAPIModel implements Classifier {
  /** Size of the score/class output buffers and upper bound on returned results. */
  private static final int MAX_RESULTS = 100;

  private String inputName;
  private int inputSize;
  private final List<String> labels = new ArrayList<String>();
  private byte[] byteValues;
  private float[] outputLocations;
  private float[] outputScores;
  private float[] outputClasses;
  private float[] outputNumDetections;
  private String[] outputNames;
  private TensorFlowInferenceInterface inferenceInterface;

  /**
   * Creates a detector.
   *
   * @param modelFilename path of the frozen graph file
   * @param labelFilename path of a text file with one label per line; the line index is the
   *     class index
   * @param inputSize width and height, in pixels, of the square image fed to the model
   * @return the ready-to-use detector
   * @throws IOException if the label file cannot be read
   * @throws RuntimeException if the model cannot be loaded or lacks one of the required nodes
   */
  public static Classifier create(
      final String modelFilename,
      final String labelFilename,
      final int inputSize) throws IOException {
    final TensorFlowObjectDetectionAPIModel d = new TensorFlowObjectDetectionAPIModel();

    // try-with-resources closes the label file even when reading fails part-way.
    // NOTE(review): the file is decoded with the platform default charset.
    try (BufferedReader br =
        new BufferedReader(new InputStreamReader(new FileInputStream(labelFilename)))) {
      String line;
      while ((line = br.readLine()) != null) {
        d.labels.add(line);
      }
    }

    d.inferenceInterface = new TensorFlowInferenceInterface(modelFilename);
    try {
      final Graph g = d.inferenceInterface.graph();
      d.inputName = "image_tensor";
      if (g.operation(d.inputName) == null) {
        throw new RuntimeException("Failed to find input Node '" + d.inputName + "'");
      }
      d.inputSize = inputSize;
      // Every output that recognizeImage fetches is validated up front, including
      // num_detections, so a wrong model fails here rather than on the first image.
      final String[] requiredOutputs = {
        "detection_scores", "detection_boxes", "detection_classes", "num_detections"
      };
      for (final String outputName : requiredOutputs) {
        if (g.operation(outputName) == null) {
          throw new RuntimeException("Failed to find output Node '" + outputName + "'");
        }
      }
    } catch (RuntimeException e) {
      // The loaded model would otherwise stay open with nobody left to close it.
      d.inferenceInterface.close();
      throw e;
    }

    d.outputNames = new String[] {"detection_boxes", "detection_scores",
        "detection_classes", "num_detections"};
    d.byteValues = new byte[d.inputSize * d.inputSize * 3];
    d.outputScores = new float[MAX_RESULTS];
    d.outputLocations = new float[MAX_RESULTS * 4];
    d.outputClasses = new float[MAX_RESULTS];
    d.outputNumDetections = new float[1];
    return d;
  }

  private TensorFlowObjectDetectionAPIModel() {}

  /**
   * Runs detection on one image.
   *
   * @param intValues packed pixels (red in bits 16-23, green in bits 8-15, blue in bits 0-7);
   *     the internal buffer has room for {@code inputSize * inputSize} pixels
   * @return up to {@code MAX_RESULTS} results ordered by descending confidence, with box
   *     coordinates multiplied by {@code inputSize}
   */
  @Override
  public List<Recognition> recognizeImage(int[] intValues) {
    // Unpack each pixel into three consecutive bytes in R, G, B order.
    for (int i = 0; i < intValues.length; ++i) {
      byteValues[i * 3 + 2] = (byte) (intValues[i] & 0xFF);
      byteValues[i * 3 + 1] = (byte) ((intValues[i] >> 8) & 0xFF);
      byteValues[i * 3 + 0] = (byte) ((intValues[i] >> 16) & 0xFF);
    }
    inferenceInterface.feed(inputName, byteValues, 1, inputSize, inputSize, 3);
    inferenceInterface.run(outputNames);

    // Fresh arrays so values from a previous image cannot survive into this result.
    outputLocations = new float[MAX_RESULTS * 4];
    outputScores = new float[MAX_RESULTS];
    outputClasses = new float[MAX_RESULTS];
    outputNumDetections = new float[1];
    inferenceInterface.fetch(outputNames[0], outputLocations);
    inferenceInterface.fetch(outputNames[1], outputScores);
    inferenceInterface.fetch(outputNames[2], outputClasses);
    inferenceInterface.fetch(outputNames[3], outputNumDetections);

    // Highest confidence first.
    final PriorityQueue<Recognition> pq =
        new PriorityQueue<Recognition>(1, new Comparator<Recognition>() {
          @Override
          public int compare(final Recognition lhs, final Recognition rhs) {
            return Float.compare(rhs.getConfidence(), lhs.getConfidence());
          }
        });
    // NOTE(review): all MAX_RESULTS slots are reported; outputNumDetections is fetched but
    // not used to stop early -- confirm whether trailing slots should be dropped.
    for (int i = 0; i < outputScores.length; ++i) {
      // Each box occupies four floats, read here as top, left, bottom, right.
      float left = outputLocations[4 * i + 1] * inputSize;
      float top = outputLocations[4 * i] * inputSize;
      float right = outputLocations[4 * i + 3] * inputSize;
      float bottom = outputLocations[4 * i + 2] * inputSize;
      pq.add(
          new Recognition(
              "" + i, labels.get((int) outputClasses[i]), outputScores[i],
              left, top, right, bottom));
    }

    // The count is fixed before polling: pq.size() shrinks with every poll(), so testing it
    // in the loop condition stopped after roughly half of the queued results.
    final int resultCount = Math.min(pq.size(), MAX_RESULTS);
    final ArrayList<Recognition> recognitions = new ArrayList<Recognition>(resultCount);
    for (int i = 0; i < resultCount; ++i) {
      recognitions.add(pq.poll());
    }
    return recognitions;
  }

  /** Closes the underlying inference interface. */
  @Override
  public void close() {
    inferenceInterface.close();
  }
}
DetectionImage.java:这个是实现类,只需要复制其中的实现代码即可
/**
 * Sample program that loads a detection model, scales one image file to the model's input
 * size, runs detection on it and prints every result. All paths are hard-coded.
 */
public class DetectionImage {
  /**
   * Runs the sample.
   *
   * @param args unused
   */
  public static void main(String[] args) {
    int input_size = 300;
    Classifier d;
    try {
      d = TensorFlowObjectDetectionAPIModel.create("C:\\sts\\ssd_mobilenet_v1_android_export.pb",
          "C:\\sts\\coco_labels_list.txt", input_size);
    } catch (IOException e) {
      e.printStackTrace();
      return;
    }
    // From here on the classifier is always closed, whatever happens to the image.
    try {
      File file = new File("C:\\sts\\person.jpg");
      Image img;
      try {
        img = ImageIO.read(file);
      } catch (IOException e) {
        e.printStackTrace();
        return;
      }
      if (img == null) {
        // ImageIO.read returns null when no registered reader can decode the file.
        System.err.println("Unsupported or unreadable image: " + file);
        return;
      }
      int width = img.getWidth(null);
      int height = img.getHeight(null);
      // Stretch the whole source image onto an input_size x input_size RGB canvas.
      BufferedImage image = new BufferedImage(input_size, input_size, BufferedImage.TYPE_INT_RGB);
      java.awt.Graphics graphics = image.getGraphics();
      try {
        graphics.drawImage(img, 0, 0, input_size, input_size, 0, 0, width, height, null);
      } finally {
        graphics.dispose();
      }
      int[] rgbs = new int[input_size * input_size];
      image.getRGB(0, 0, input_size, input_size, rgbs, 0, input_size);
      List<Recognition> results = d.recognizeImage(rgbs);
      for (Recognition result : results) {
        System.out.println(result.toString());
      }
    } finally {
      d.close();
    }
  }
}
输出结果
[0] person (99.1%) 113.1657 27.434679 183.88785 296.03595
没想到java用起来还是蛮6的,果然android不行了还可以转行java。。。还是程序猿。。。