Java调用Tensorflow训练出来的模型

23 篇文章 0 订阅
14 篇文章 1 订阅

训练用的网络见下面这篇博客↓↓↓

Tensorflow 直接对验证码进行3通道卷积后识别

对上面这篇博客中的网络稍作修改,以便于Java调用,下面给出Java桌面级应用以及Android应用的核心代码。

Java桌面级应用(下面先给出修改后的Python训练/导出脚本)

import tensorflow as tf
import numpy as np
from PIL import Image
import os
import random

# Location of the training captcha images (the author's local Windows path).
# gen_train_data() uses each file's name stem as that image's label.
train_data_dir = r'C:\Users\HUPENG\Desktop\check_code_crack\check_code\train'
# Test-set directory placeholder; not referenced elsewhere in this script.
test_data_dir = r''

# Listed once at import time; batches are sampled from this list.
train_file_name_list = os.listdir(train_data_dir)

def gen_train_data(batch_size=32):
    """Sample a random training batch of captcha images and their digit labels.

    Draws ``batch_size`` distinct file names from ``train_file_name_list``,
    loads each image from ``train_data_dir``, and takes the label from the
    file name: the text before the first '.' is converted to one int per
    character.

    Args:
        batch_size: number of distinct files to sample. ``random.sample``
            raises ``ValueError`` if it exceeds the number of files.

    Returns:
        ``(x_data, y_data)``. ``x_data`` stacks the images as decoded by
        ``np.array(Image)``; ``y_data`` holds one int32 row of digits per
        image.
        NOTE(review): stacking assumes all images share a size/mode and all
        labels share a length. Confirm this for the dataset.
    """
    selected_names = random.sample(train_file_name_list, batch_size)
    x_data = []
    y_data = []
    for name in selected_names:
        # Use a context manager so the file handle is released right away,
        # not whenever the Image object happens to be garbage-collected.
        with Image.open(os.path.join(train_data_dir, name)) as captcha_image:
            x_data.append(np.array(captcha_image))
        label = name.split('.')[0]
        y_data.append(np.array([int(ch) for ch in label], dtype=np.int32))
    return np.array(x_data), np.array(y_data)

# Sole graph input. Callers feed a flat float pixel buffer, which net()
# reshapes; the Java and Android code feed the node by the name "input".
X = tf.placeholder(tf.float32, name="input")

# The training-time label placeholder (Y), its one-hot encoding and a
# keep_prob placeholder were removed for export. Dropout keep probability
# is fixed at 1.0 (nothing dropped), so only "input" has to be fed.
keep_prob = 1.0
def _conv_pool_dropout(inputs, in_channels, out_channels, pool_size, w_alpha, b_alpha):
    """Apply one 3x3 SAME conv + ReLU, a stride-2 SAME max-pool and dropout.

    Creates the weight variable before the bias variable, matching the
    creation order of the original inline implementation.
    """
    w = tf.Variable(w_alpha * tf.random_normal([3, 3, in_channels, out_channels]))
    b = tf.Variable(b_alpha * tf.random_normal([out_channels]))
    conv = tf.nn.relu(tf.nn.bias_add(
        tf.nn.conv2d(inputs, w, strides=[1, 1, 1, 1], padding='SAME'), b))
    conv = tf.nn.max_pool(conv, ksize=[1, pool_size, pool_size, 1],
                          strides=[1, 2, 2, 1], padding='SAME')
    return tf.nn.dropout(conv, keep_prob)


def net(w_alpha=0.01, b_alpha=0.1):
    """Build the captcha CNN and return per-position digit predictions.

    Three conv/pool/dropout stages feed a 1024-unit ReLU layer and then a
    linear layer with 5 * 10 outputs, reshaped to (batch, 5, 10). The graph
    ends with an argmax over the last axis, cast to float32 and named
    ``"output"``. That is the node the Java and Android callers fetch.

    Args:
        w_alpha: scale applied to the standard-normal initial weights.
        b_alpha: scale applied to the standard-normal initial biases.

    Returns:
        A float32 tensor of shape (batch, 5) with the predicted digit at
        each of the 5 positions.
    """
    # NOTE(review): the Java caller (Thumbnailator forceSize(218, 82)) and the
    # Android caller (getPixels with width 218, height 82) appear to pass
    # row-major pixels of a 218-wide, 82-tall image, i.e. a (82, 218, 3)
    # memory layout. Reshaping to (218, 82, 3) reinterprets that buffer
    # rather than transposing it. It is left unchanged because trained
    # checkpoints depend on it; confirm before retraining.
    x_reshape = tf.reshape(X, (-1, 218, 82, 3))

    # Variables are created in the same order as before (weight then bias,
    # layer by layer) so their auto-generated names, which tf.train.Saver
    # uses as checkpoint keys, are unchanged.
    conv1 = _conv_pool_dropout(x_reshape, 3, 16, 4, w_alpha, b_alpha)
    conv2 = _conv_pool_dropout(conv1, 16, 16, 2, w_alpha, b_alpha)
    conv3 = _conv_pool_dropout(conv2, 16, 16, 2, w_alpha, b_alpha)

    # Fully connected layer. The flattened size comes from conv3's static
    # shape (28 * 11 * 16 for the 218x82 input) instead of being hard-coded.
    flat_size = 1
    for dim in conv3.get_shape().as_list()[1:]:
        flat_size *= dim
    w_d = tf.Variable(w_alpha * tf.random_normal([flat_size, 1024]))
    b_d = tf.Variable(b_alpha * tf.random_normal([1024]))
    dense = tf.reshape(conv3, [-1, flat_size])
    dense = tf.nn.relu(tf.add(tf.matmul(dense, w_d), b_d))

    w_out = tf.Variable(w_alpha * tf.random_normal([1024, 5 * 10]))
    b_out = tf.Variable(b_alpha * tf.random_normal([5 * 10]))
    logits = tf.add(tf.matmul(dense, w_out), b_out)
    logits = tf.reshape(logits, (-1, 5, 10))

    # Softmax is monotonic, so argmax over the raw logits selects the same
    # class. A softmax in front of the argmax would be redundant.
    out = tf.argmax(logits, 2)
    return tf.cast(out, tf.float32, name="output")
# Build the inference graph once at import time. train() runs this tensor,
# and the tf.train.Saver in exportModel() covers the variables it creates.
cnn = net()
# The training build defined a sigmoid cross-entropy loss against y_one_hot
# and an Adam optimizer (learning rate 0.001) here. Both are disabled in this
# export build because net() now ends in a non-differentiable argmax.

def train() -> None:
    """Train until the loss drops below 0.001, then save a checkpoint.

    Each iteration samples a 64-image batch, runs the loss/optimizer ops and
    prints the loss. Once the loss is below 0.001 it saves
    ``./crack_capcha.model`` with the current step as ``global_step`` and
    stops.

    NOTE(review): in this export build, ``loss``, ``Y``, ``y_one_hot`` and
    ``optimizer`` are commented out at module level, and ``keep_prob`` is a
    plain float rather than a placeholder. As written, train() therefore
    appears to raise NameError on its first ``sess.run``. Before training,
    restore those definitions and a differentiable (pre-argmax) output.
    """
    saver = tf.train.Saver()
    with tf.Session() as sess:
        step = 0
        tf.global_variables_initializer().run()
        while True:
            x_data, y_data = gen_train_data(64)
            # Flatten the whole batch; X has no fixed shape and net() reshapes it.
            x_data = np.reshape(x_data, (-1))
            loss_, cnn_, y_one_hot_, optimizer_ = sess.run([loss, cnn, y_one_hot, optimizer],
                                                           feed_dict={Y: y_data, X: x_data, keep_prob: 0.75})

            print(loss_)
            if loss_ < 0.001:
                saver.save(sess, "./crack_capcha.model", global_step=step)
                print("save model successful!")
                break

            # Debug snippet for inspecting the network output shape:
            # cnn_ = sess.run(cnn, feed_dict={Y:y_data, X:x_data})
            # print(cnn_.shape)
            # break
            step += 1
def exportModel(checkpoint_path="crack_capcha.model-941", output_path="model.pb"):
    """Freeze a trained checkpoint into a standalone serialized GraphDef.

    Restores the model variables from ``checkpoint_path``, converts the
    variables that the ``output`` node depends on into constants, and writes
    the resulting GraphDef to ``output_path``. The Java and Android code
    load that file as ``model.pb``.

    Args:
        checkpoint_path: checkpoint prefix passed to ``Saver.restore``.
            Defaults to the step-941 checkpoint used originally.
        output_path: destination file for the frozen graph.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants

    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Restore the trained model parameters.
        saver.restore(sess, checkpoint_path)
        output_graph_def = convert_variables_to_constants(
            sess, sess.graph_def, output_node_names=['output'])
    # tf.gfile.FastGFile is deprecated; tf.gfile.GFile is its replacement.
    with tf.gfile.GFile(output_path, mode='wb') as f:
        f.write(output_graph_def.SerializeToString())

if __name__ == '__main__':
    # Default action: freeze the trained checkpoint into model.pb.
    # (train() is not usable in this export build.)
    exportModel()
    print("ok")

导出模型文件为model.pb
Java工程的依赖如下:
这里写图片描述
整个的项目结构如下:
这里写图片描述
下面写Java端的调用
ECardCaptchaCrack.java

package me.hupeng.sdk.ecardcaptchacrack;

import net.coobird.thumbnailator.Thumbnails;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.awt.image.BufferedImage;
import java.awt.image.Raster;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

public class ECardCaptchaCrack {
    /**
     * Byte size of the bundled model.pb when this class was written. Kept for
     * compatibility and used only as a buffer-size hint: the model is now read
     * to end of stream, so a retrained model of a different size also loads.
     */
    public static final int MAGIC_NUMBER = 20418991;

    /** Serialized frozen TensorFlow graph read from the "model.pb" classpath resource. */
    static byte[] graphDef = loadGraphDef();

    /**
     * Reads the whole "model.pb" resource into memory.
     *
     * A single InputStream.read call may return fewer bytes than requested,
     * so this loops until end of stream. If the resource is missing or
     * unreadable, it logs the problem and returns an empty array.
     */
    private static byte[] loadGraphDef() {
        java.io.InputStream in = ClassLoader.getSystemClassLoader().getResourceAsStream("model.pb");
        if (in == null) {
            System.err.println("model.pb not found on the classpath");
            return new byte[0];
        }
        try (BufferedInputStream bis = new BufferedInputStream(in)) {
            java.io.ByteArrayOutputStream out = new java.io.ByteArrayOutputStream(MAGIC_NUMBER);
            byte[] buffer = new byte[8192];
            int n;
            while ((n = bis.read(buffer)) != -1) {
                out.write(buffer, 0, n);
            }
            return out.toByteArray();
        } catch (IOException e) {
            e.printStackTrace();
            return new byte[0];
        }
    }

    /**
     * Predicts the digits in the captcha image at {@code filepath}.
     *
     * Steps:
     * 1. Resize the image to 218x82 with Thumbnailator.
     * 2. Flatten all pixel samples into a float array and feed it to the
     *    "input" node.
     * 3. Truncate the five floats from the "output" node to ints and
     *    concatenate them.
     *
     * @param filepath path of the captcha image
     * @return the predicted digits, or "" if the file does not exist or the
     *         image cannot be read
     */
    public static String crack(String filepath){
        File file = new File(filepath);
        if (!file.exists()){
            return "";
        }

        StringBuilder result = new StringBuilder();
        try (Graph g = new Graph()) {
            g.importGraphDef(graphDef);
            BufferedImage im;
            try {
                im = Thumbnails.of(filepath).forceSize(218,82).outputFormat("bmp").asBufferedImage();
            } catch (IOException e) {
                e.printStackTrace();
                return "";
            }
            Raster raster = im.getData();
            float[] pixels = raster.getPixels(0, 0, raster.getWidth(), raster.getHeight(),
                    new float[raster.getWidth() * raster.getHeight() * raster.getNumBands()]);
            // Tensors hold native memory: close the input as well as the
            // session and output instead of leaking it on every call.
            try (Tensor input = Tensor.create(pixels, Float.class);
                 Session s = new Session(g);
                 Tensor output = s.runner().feed("input", input).fetch("output").run().get(0).expect(Float.class)) {
                float[][] prediction = new float[1][5];
                output.copyTo(prediction);
                for (float digit : prediction[0]) {
                    result.append((int) digit);
                }
            }
        }
        return result.toString();
    }

    /** Reads every byte of {@code path}, terminating the JVM with status 1 on failure. */
    private static byte[] readAllBytesOrExit(Path path) {
        try {
            return Files.readAllBytes(path);
        } catch (IOException e) {
            System.err.println("Failed to read [" + path + "]: " + e.getMessage());
            System.exit(1);
        }
        return null;
    }
}

Main.java

package me.hupeng.sdk.ecardcaptchacrack;

public class Main {

    /**
     * Command-line demo: cracks the captcha "1.jpg" in the working directory
     * and prints the predicted digits (an empty line if the file is missing).
     */
    public static void main(String[] args) {
        final String imagePath = "1.jpg";
        String prediction = ECardCaptchaCrack.crack(imagePath);
        System.out.println(prediction);
    }
}

上面的代码预测的是1.jpg这张图
这里写图片描述
程序输出为:
这里写图片描述

Android版本:

先把pb文件(model.pb)拷贝到assets目录下,然后加入以下核心类:

package me.hupeng.app.ecardcaptchaapp;

import android.content.Context;
import android.os.Handler;
import android.os.Looper;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

/**
 * Android wrapper around the frozen captcha model in assets/model.pb.
 * Inference runs on a background thread; the result is posted to the
 * listener through a Handler bound to the main looper.
 *
 * Created by HUPENG on 2018/6/4.
 */

public class ECardCaptchaCrack {
    private static final String MODEL_FILE="file:///android_asset/model.pb";
    private Context mContext;
    private final TensorFlowInferenceInterface inferenceInterface;
    private Handler handler = new Handler(Looper.getMainLooper());

    /** Receives the predicted digits; invoked on the main thread. */
    public static interface ECardCaptchaCrackListener{
        public void callback(String out);
    }

    private ECardCaptchaCrackListener tensorflowRunnerListener;

    /**
     * Loads the model from the app's assets.
     *
     * @param context stored, and used to obtain the AssetManager
     * @param tensorflowRunnerListener receives every prediction
     */
    public ECardCaptchaCrack(Context context, ECardCaptchaCrackListener tensorflowRunnerListener){
        this.mContext = context;
        this.tensorflowRunnerListener = tensorflowRunnerListener;
        inferenceInterface = new TensorFlowInferenceInterface(this.mContext.getAssets(), MODEL_FILE);
    }

    /**
     * Runs the model on one image asynchronously.
     *
     * @param data 218 * 82 * 3 pixel values fed to the "input" node
     */
    public void add(final float[] data){
        new Thread(new Runnable() {
            @Override
            public void run() {
                final float[] out = new float[5];
                // feed, run and fetch are separate calls on one shared
                // inferenceInterface. Every add() starts its own thread, so
                // serialize the sequence to keep concurrent calls from
                // interleaving each other's inputs and outputs.
                synchronized (inferenceInterface) {
                    inferenceInterface.feed("input", data, 218 * 82 * 3);
                    inferenceInterface.run(new String[]{"output"});
                    inferenceInterface.fetch("output", out);
                }
                final String result = digitsToString(out);

                handler.post(new Runnable() {
                    @Override
                    public void run() {
                        tensorflowRunnerListener.callback(result);
                    }
                });
            }
        }).start();
    }

    /** Truncates each predicted float to an int and concatenates the digits. */
    private static String digitsToString(float[] digits) {
        StringBuilder sb = new StringBuilder(digits.length);
        for (float d : digits) {
            sb.append((int) d);
        }
        return sb.toString();
    }
}

调用方式如下:

// Load the captcha image from assets.
Bitmap bitmap = getImageFromAssetsFile("22232.jpg");
// The model expects a 218x82 image, and the desktop version forces that size
// too. getPixels below throws if the bitmap is smaller, so scale it first.
// createScaledBitmap returns the same bitmap when it is already 218x82.
bitmap = Bitmap.createScaledBitmap(bitmap, 218, 82, true);
// Read all pixels as packed color ints, row by row.
int[] pixels = new int[218 * 82];
bitmap.getPixels(pixels, 0, 218, 0, 0, 218, 82);
// Unpack into interleaved R, G, B values for the 3-channel model input.
float[] pixels2float = new float[218 * 82 * 3];
for (int i = 0; i < 218 * 82; i++) {
    pixels2float[i * 3] = Color.red(pixels[i]);
    pixels2float[i * 3 + 1] = Color.green(pixels[i]);
    pixels2float[i * 3 + 2] = Color.blue(pixels[i]);
}
// Create the cracker; the callback shows the prediction as a Toast.
ECardCaptchaCrack eCardCaptchaCrack = new ECardCaptchaCrack(this, new ECardCaptchaCrack.ECardCaptchaCrackListener() {
    @Override
    public void callback(String out) {
        Toast.makeText(MainActivity.this, out, Toast.LENGTH_LONG).show();
    }
});

// Submit the pixels and wait for the result callback.
eCardCaptchaCrack.add(pixels2float);
  • 1
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 6
    评论
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值