训练用的网络见下面这篇博客↓↓↓
Tensorflow 直接对验证码进行3通道卷积后识别
对上篇博客中的网络稍作修改,以便于Java调用。下面给出Java桌面级应用以及Android应用的核心代码。
Java桌面级应用
import tensorflow as tf
import numpy as np
from PIL import Image
import os
import random
# Directory holding the training captcha images (machine-specific absolute path).
train_data_dir = r'C:\Users\HUPENG\Desktop\check_code_crack\check_code\train'
# Test-set directory; left empty and not referenced by the visible code.
test_data_dir = r''
# File names found in the training directory, listed once when the module loads.
train_file_name_list = os.listdir(train_data_dir)
def gen_train_data(batch_size=32):
    """Draw a random batch of captcha images together with their digit labels.

    File names are sampled without replacement from the module-level
    ``train_file_name_list``. The label of an image is the part of its file
    name before the first ``.``, split into single characters and converted
    to ``int32``.

    Args:
        batch_size: Number of images to sample.

    Returns:
        A ``(x_data, y_data)`` tuple of NumPy arrays: the stacked pixel
        arrays and the per-character integer labels.

    Raises:
        ValueError: If ``batch_size`` is larger than the number of training
            files (raised by ``random.sample``), or if a file name stem
            contains a non-digit character.
    """
    selected_names = random.sample(train_file_name_list, batch_size)
    x_data = []
    y_data = []
    for file_name in selected_names:
        # Image.open is lazy and keeps the file handle open; the context
        # manager releases it as soon as the pixels have been copied out.
        with Image.open(os.path.join(train_data_dir, file_name)) as captcha_image:
            x_data.append(np.array(captcha_image))
        label_text = file_name.split('.')[0]
        y_data.append(np.array(list(label_text)).astype(np.int32))
    return np.array(x_data), np.array(y_data)
# Network input placeholder; the explicit node name "input" is what the Java
# and Android callers feed.
X = tf.placeholder(tf.float32, name="input")
# Training-only label placeholder and one-hot encoding, disabled in this
# export-oriented version of the script:
# Y = tf.placeholder(tf.int32)
# keep_prob = tf.placeholder(tf.float32)
# y_one_hot = tf.one_hot(Y, 10, 1, 0)
# y_one_hot = tf.cast(y_one_hot, tf.float32)
# Dropout keep probability as a plain Python float (replaces the placeholder
# that is commented out above).
keep_prob = 1.0
def net(w_alpha=0.01, b_alpha=0.1):
    """Build the captcha-recognition graph on top of the global placeholder ``X``.

    The graph reshapes ``X`` to ``(-1, 218, 82, 3)``, applies three
    conv + ReLU + max-pool + dropout stages with 16 filters each, one
    1024-unit dense layer and a ``5 * 10`` output layer, then takes the
    arg-max over the last axis of the ``(-1, 5, 10)`` scores.

    NOTE(review): all variables are created unnamed, so checkpoint
    compatibility presumably depends on their creation order (weight before
    bias, layer by layer) -- keep that order when editing.
    NOTE(review): callers appear to supply pixels row by row for a
    218-wide, 82-high image; confirm the 218 x 82 axis order is intended.

    Args:
        w_alpha: Scale applied to the random-normal weight initialisers.
        b_alpha: Scale applied to the random-normal bias initialisers.

    Returns:
        A float32 tensor named ``"output"`` holding the five predicted
        class indices per sample.
    """

    def conv_stage(inputs, in_channels, out_channels, pool_size):
        # One 3x3 convolution stage; the weight is created before the bias.
        kernel = tf.Variable(w_alpha * tf.random_normal([3, 3, in_channels, out_channels]))
        bias = tf.Variable(b_alpha * tf.random_normal([out_channels]))
        convolved = tf.nn.conv2d(inputs, kernel, strides=[1, 1, 1, 1], padding='SAME')
        activated = tf.nn.relu(tf.nn.bias_add(convolved, bias))
        pooled = tf.nn.max_pool(activated, ksize=[1, pool_size, pool_size, 1],
                                strides=[1, 2, 2, 1], padding='SAME')
        return tf.nn.dropout(pooled, keep_prob)

    features = tf.reshape(X, (-1, 218, 82, 3))
    # (input channels, output channels, pooling window) for each stage.
    for in_channels, out_channels, pool_size in ((3, 16, 4), (16, 16, 2), (16, 16, 2)):
        features = conv_stage(features, in_channels, out_channels, pool_size)

    # Fully connected layer: randomly initialised weight, then bias.
    flat_size = 28 * 11 * 16
    w_d = tf.Variable(w_alpha * tf.random_normal([flat_size, 1024]))
    b_d = tf.Variable(b_alpha * tf.random_normal([1024]))
    flattened = tf.reshape(features, [-1, flat_size])
    dense = tf.nn.relu(tf.add(tf.matmul(flattened, w_d), b_d))

    # Output layer: 5 positions x 10 classes.
    w_out = tf.Variable(w_alpha * tf.random_normal([1024, 5 * 10]))
    b_out = tf.Variable(b_alpha * tf.random_normal([5 * 10]))
    scores = tf.add(tf.matmul(dense, w_out), b_out)
    scores = tf.reshape(scores, (-1, 5, 10))
    probabilities = tf.nn.softmax(scores)
    predicted = tf.argmax(probabilities, 2)
    return tf.cast(predicted, tf.float32, name="output")
# Build the graph once when the module loads; `cnn` is the tensor returned by net().
cnn = net()
# Training-only loss and optimizer, disabled in this export-oriented version:
# loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=cnn, labels=y_one_hot))
# optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
def train():
    """Train on random batches of 64 until the loss drops below 0.001, then checkpoint.

    Each iteration prints the loss. Once it is below 0.001 the session is
    saved to ``./crack_capcha.model`` (tagged with the step counter) and the
    loop ends.

    NOTE(review): this function uses the module-level names ``loss``,
    ``optimizer``, ``y_one_hot`` and ``Y``, which only appear in
    commented-out lines of this script, and it feeds ``keep_prob`` even
    though that is currently a plain float. Restore those definitions
    before calling it.
    """
    checkpoint_saver = tf.train.Saver()
    with tf.Session() as sess:
        step = 0
        tf.global_variables_initializer().run()
        while True:
            batch_images, batch_labels = gen_train_data(64)
            flat_images = np.reshape(batch_images, (-1))
            fetches = [loss, cnn, y_one_hot, optimizer]
            feed = {Y: batch_labels, X: flat_images, keep_prob: 0.75}
            loss_, cnn_, y_one_hot_, optimizer_ = sess.run(fetches, feed_dict=feed)
            print(loss_)
            if not loss_ < 0.001:
                step += 1
                continue
            checkpoint_saver.save(sess, "./crack_capcha.model", global_step=step)
            print("save model successful!")
            break
def exportModel(checkpoint_path="crack_capcha.model-941", output_path='model.pb'):
    """Freeze a trained checkpoint into a single GraphDef file.

    Restores the variables from ``checkpoint_path`` into the current default
    graph, converts them to constants keeping everything needed to compute
    the ``output`` node, and writes the serialized graph to ``output_path``.

    Args:
        checkpoint_path: Checkpoint prefix to restore. Defaults to the
            checkpoint name previously hard-coded in this script.
        output_path: Destination of the frozen graph. Defaults to
            ``model.pb``.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants

    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Restore the trained model parameters.
        saver.restore(sess, checkpoint_path)
        output_graph_def = convert_variables_to_constants(
            sess, sess.graph_def, output_node_names=['output'])
    # GFile replaces the deprecated FastGFile alias.
    with tf.gfile.GFile(output_path, mode='wb') as f:
        f.write(output_graph_def.SerializeToString())
if __name__ == "__main__":
    # Script entry point: freeze the trained checkpoint.
    # To retrain instead, call train() here in place of exportModel().
    exportModel()
    print("ok")
导出模型文件为model.pb
Java工程的依赖如下:
整个的项目结构如下:
下面写Java端的调用
ECardCaptchaCrack.java
package me.hupeng.sdk.ecardcaptchacrack;
import java.awt.image.BufferedImage;
import java.awt.image.Raster;
import java.io.BufferedInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import net.coobird.thumbnailator.Thumbnails;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
/**
 * Recognises five-digit captchas with the frozen TensorFlow graph
 * {@code model.pb} loaded from the classpath.
 */
public class ECardCaptchaCrack {
    /**
     * Size in bytes of the originally shipped {@code model.pb}. It is now only
     * used as the initial buffer capacity when loading the model, so a model of
     * a different size is still read completely.
     */
    public static final int MAGIC_NUMBER = 20418991;

    /** Serialized GraphDef; empty when the model could not be loaded. */
    static byte[] graphDef = loadGraphDef();

    /**
     * Reads the whole {@code model.pb} classpath resource.
     *
     * @return the model bytes, or an empty array if the resource is missing or
     *         unreadable (the failure is reported on stderr)
     */
    private static byte[] loadGraphDef() {
        InputStream resource = ClassLoader.getSystemClassLoader().getResourceAsStream("model.pb");
        if (resource == null) {
            System.err.println("model.pb was not found on the classpath");
            return new byte[0];
        }
        try (InputStream in = new BufferedInputStream(resource);
             ByteArrayOutputStream out = new ByteArrayOutputStream(MAGIC_NUMBER)) {
            // A single read() may return fewer bytes than requested, so loop until EOF.
            byte[] chunk = new byte[8192];
            int n;
            while ((n = in.read(chunk)) != -1) {
                out.write(chunk, 0, n);
            }
            return out.toByteArray();
        } catch (IOException e) {
            e.printStackTrace();
            return new byte[0];
        }
    }

    /**
     * Predicts the digits shown in a captcha image.
     *
     * @param filepath path of the image file; it is resized to 218x82 before inference
     * @return the five predicted digits, or an empty string if the file does not
     *         exist or cannot be decoded
     * @throws IllegalStateException if the model could not be loaded at class initialisation
     */
    public static String crack(String filepath) {
        File file = new File(filepath);
        if (!file.exists()) {
            return "";
        }
        if (graphDef.length == 0) {
            throw new IllegalStateException("model.pb could not be loaded; see earlier error output");
        }

        BufferedImage im;
        try {
            im = Thumbnails.of(filepath).forceSize(218, 82).outputFormat("bmp").asBufferedImage();
        } catch (IOException e) {
            e.printStackTrace();
            return "";
        }

        Raster raster = im.getData();
        int width = raster.getWidth();
        int height = raster.getHeight();
        float[] pixels = raster.getPixels(0, 0, width, height,
                new float[width * height * raster.getNumBands()]);

        StringBuilder result = new StringBuilder();
        // Graph, Session and both tensors hold native memory, so close every one of them.
        try (Graph g = new Graph();
             Tensor<Float> input = Tensor.create(pixels, Float.class)) {
            g.importGraphDef(graphDef);
            try (Session s = new Session(g);
                 Tensor<Float> output = s.runner().feed("input", input).fetch("output").run().get(0).expect(Float.class)) {
                float[][] digits = new float[1][5];
                output.copyTo(digits);
                for (int i = 0; i < 5; i++) {
                    result.append((int) digits[0][i]);
                }
            }
        }
        return result.toString();
    }

    /**
     * Reads a file completely, terminating the JVM with exit code 1 on failure.
     *
     * @param path file to read
     * @return the file contents
     */
    private static byte[] readAllBytesOrExit(Path path) {
        try {
            return Files.readAllBytes(path);
        } catch (IOException e) {
            System.err.println("Failed to read [" + path + "]: " + e.getMessage());
            System.exit(1);
        }
        return null;
    }
}
Main.java
package me.hupeng.sdk.ecardcaptchacrack;
/** Command-line demo that prints the prediction for a sample captcha. */
public class Main {
    /**
     * Runs the recogniser on {@code 1.jpg} and prints the result to stdout.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        String recognized = ECardCaptchaCrack.crack("1.jpg");
        System.out.println(recognized);
    }
}
上面的代码预测的是1.jpg这张图
程序输出为:
Android版本:
先把pb文件拷入到assets目录下面,然后加入以下核心类:
package me.hupeng.app.ecardcaptchaapp;
import android.content.Context;
import android.os.Handler;
import android.os.Looper;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
/**
* Created by HUPENG on 2018/6/4.
*/
public class ECardCaptchaCrack {
    private static final String MODEL_FILE = "file:///android_asset/model.pb";
    /** Number of floats fed per image: 218 x 82 pixels, 3 channels. */
    private static final int INPUT_LENGTH = 218 * 82 * 3;
    /** Number of digits fetched from the "output" node. */
    private static final int DIGIT_COUNT = 5;

    private final Context mContext;
    private final TensorFlowInferenceInterface inferenceInterface;
    // Results are delivered on the main thread so listeners may touch the UI.
    private final Handler handler = new Handler(Looper.getMainLooper());

    /** Receives the recognised captcha text. */
    public static interface ECardCaptchaCrackListener {
        /**
         * @param out the predicted digits, concatenated
         */
        public void callback(String out);
    }

    private final ECardCaptchaCrackListener tensorflowRunnerListener;

    /**
     * Loads the model from the application's assets.
     *
     * @param context                  context whose assets contain the model file
     * @param tensorflowRunnerListener listener notified with every result
     */
    public ECardCaptchaCrack(Context context, ECardCaptchaCrackListener tensorflowRunnerListener) {
        this.mContext = context;
        this.tensorflowRunnerListener = tensorflowRunnerListener;
        inferenceInterface = new TensorFlowInferenceInterface(this.mContext.getAssets(), MODEL_FILE);
    }

    /**
     * Runs inference on a background thread and posts the result to the
     * listener on the main thread.
     *
     * @param data pixel values of one image; must contain exactly 218 * 82 * 3 floats
     * @throws IllegalArgumentException if {@code data} is null or has the wrong length
     */
    public void add(final float[] data) {
        // Fail on the caller's thread instead of crashing the worker thread inside feed().
        if (data == null || data.length != INPUT_LENGTH) {
            throw new IllegalArgumentException(
                    "data must contain exactly " + INPUT_LENGTH + " floats");
        }
        new Thread(new Runnable() {
            @Override
            public void run() {
                final float[] out = new float[DIGIT_COUNT];
                // Every add() call starts its own thread; feed/run/fetch on the shared
                // interface must not interleave between overlapping calls.
                synchronized (inferenceInterface) {
                    inferenceInterface.feed("input", data, 218 * 82 * 3);
                    inferenceInterface.run(new String[]{"output"});
                    inferenceInterface.fetch("output", out);
                }
                handler.post(new Runnable() {
                    @Override
                    public void run() {
                        StringBuilder s = new StringBuilder();
                        for (int i = 0; i < DIGIT_COUNT; i++) {
                            s.append((int) out[i]);
                        }
                        if (tensorflowRunnerListener != null) {
                            tensorflowRunnerListener.callback(s.toString());
                        }
                    }
                });
            }
        }).start();
    }
}
调用方式如下:
// Load the JPEG captcha image from the assets.
Bitmap bitmap = getImageFromAssetsFile("22232.jpg");
// The network input is fixed at 218x82. Scale first so getPixels() cannot fail
// on a smaller image and a larger one is not merely cropped.
Bitmap scaled = Bitmap.createScaledBitmap(bitmap, 218, 82, true);
// Read every pixel as a packed ARGB int, row by row.
int[] pixels = new int[218 * 82];
scaled.getPixels(pixels, 0, 218, 0, 0, 218, 82);
// Unpack into interleaved R, G, B float values.
float[] pixels2float = new float[218 * 82 * 3];
for (int i = 0; i < 218 * 82; i++) {
    pixels2float[i * 3 + 0] = Color.red(pixels[i]);
    pixels2float[i * 3 + 1] = Color.green(pixels[i]);
    pixels2float[i * 3 + 2] = Color.blue(pixels[i]);
}
// Create the recogniser; the result is shown in a toast.
ECardCaptchaCrack eCardCaptchaCrack = new ECardCaptchaCrack(this, new ECardCaptchaCrack.ECardCaptchaCrackListener() {
    @Override
    public void callback(String out) {
        Toast.makeText(MainActivity.this, out, Toast.LENGTH_LONG).show();
    }
});
// Submit the pixels and wait for the result callback.
eCardCaptchaCrack.add(pixels2float);