keras模型h5格式转换tfLite和手机端部署
格式转换
1. checkpoint的model保存为basemodel
# Rebuild the base model from checkpoint weights and re-save it as a
# standalone Keras HDF5 file (prerequisite for the TFLite conversion below).
nclass = 5990  # number of output classes; the last one is used as the CTC blank (see decode)
# Renamed from `input` to avoid shadowing the Python builtin.
input_tensor = Input(shape=(32, 280, 1), name='the_input')
y_pred = dense_cnn_svd(input_tensor, nclass, 92)
basemodel = Model(inputs=input_tensor, outputs=y_pred)
basemodel.load_weights("densenet_svd/dense_svd.h5")
basemodel.summary()
tf.keras.models.save_model(basemodel, "densenet_svd/dense_svd_new.h5")
2. 转换,目前尝试成功的最低版本为tensorflow1.12.0
# Convert the saved Keras model to TFLite.
# Verified down to TensorFlow 1.12.0; in newer releases use
# tf.lite.TFLiteConverter.from_keras_model_file instead.
converter = tf.contrib.lite.TocoConverter.from_keras_model_file("densenet_svd/dense_svd_new.h5")
tflite_model = converter.convert()
# Context manager ensures the file handle is flushed and closed.
with open("converted_model.tflite", "wb") as f:
    f.write(tflite_model)
3. 量化,只需要加个flag
# Same conversion with post-training quantization -- only one extra flag.
converter = tf.contrib.lite.TocoConverter.from_keras_model_file("densenet_svd/dense_svd_new.h5")
converter.post_training_quantize = True  # enable the quantization pass
tflite_model = converter.convert()
# Context manager ensures the file handle is flushed and closed.
with open("converted_model_quant.tflite", "wb") as f:
    f.write(tflite_model)
也可以直接从session转,特别是当有自定义层和自定义函数的时候
# Convert directly from the live session -- useful when the model contains
# custom layers/functions that from_keras_model_file cannot deserialize.
sess = K.get_session()
graph = tf.get_default_graph()
print(model.input.name, model.output.name)
input_tensor = graph.get_tensor_by_name(model.input.name)
output_tensor = graph.get_tensor_by_name(model.output.name)
converter = tf.contrib.lite.TocoConverter.from_session(sess, [input_tensor], [output_tensor])
converter.post_training_quantize = True
tflite_model = converter.convert()
# Context manager ensures the file handle is flushed and closed.
with open("ssd_mobile.tflite", "wb") as f:
    f.write(tflite_model)
4. python端调用推断
# --- Python-side inference with the converted TFLite model ---
# Build the interpreter and allocate tensor memory.
interpreter = tf.contrib.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()
# Query input/output tensor metadata.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Model input shape, e.g. [1, 32, 280, 1] (NHWC, single grayscale channel).
input_shape = input_details[0]['shape']
print(input_shape)
# input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
# Read one image as grayscale and resize it to the model's input shape.
input_data = cv2.imread("images/2.jpg", cv2.IMREAD_GRAYSCALE)
input_data = input_data.astype("float32") / 255.0 - 0.5  # normalize to [-0.5, 0.5)
# cv2.resize expects (width, height); derive both from input_shape instead of
# hard-coding 280x32 so this snippet follows whatever model was converted.
input_data = cv2.resize(input_data, (int(input_shape[2]), int(input_shape[1])))
input_data = input_data.reshape([1] + list(input_data.shape) + [1])
# Feed the input tensor.
print(input_data.shape)
interpreter.set_tensor(input_details[0]['index'], input_data)
# Run inference.
interpreter.invoke()
# Fetch the raw per-timestep class scores and decode them to text.
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
print(decode(output_data))
安卓端部署
安卓端的部署分为以下几个步骤
- 导入模型实例,key文件等
- 传入bitmap
- bitmap转化为bytebuffer
- run,然后取出输出数组
- 解码
1. 导入模型文件,key文件等
/**
 * Memory-maps a TFLite model file shipped in the APK assets.
 *
 * @param assets        the app's AssetManager
 * @param modelFilename asset path of the .tflite file
 * @return a read-only MappedByteBuffer over the model bytes
 * @throws IOException if the asset cannot be opened or mapped
 */
private static MappedByteBuffer loadModelFile(AssetManager assets, String modelFilename)
    throws IOException {
  AssetFileDescriptor fileDescriptor = assets.openFd(modelFilename);
  // try-with-resources closes the stream (and its channel); the mapping
  // remains valid after the channel is closed.
  try (FileInputStream inputStream =
           new FileInputStream(fileDescriptor.getFileDescriptor())) {
    FileChannel fileChannel = inputStream.getChannel();
    long startOffset = fileDescriptor.getStartOffset();
    long declaredLength = fileDescriptor.getDeclaredLength();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
  }
}
import org.tensorflow.lite.Interpreter;
...
// Load the model and instantiate the TFLite interpreter.
// NOTE(review): tfLite must be a field (it is used later by run()); the
// original declared it as a local inside the try, where it went out of scope.
try {
    tfLite = new Interpreter(loadModelFile(assetManager, modelFilename));
} catch (Exception e) {
    throw new RuntimeException(e);
}
// Read the key file: one long line mapping class index -> character.
try {
    BufferedReader reader = new BufferedReader(new InputStreamReader(assetManager.open("key.txt")));
    String line;
    while ((line = reader.readLine()) != null) {
        keys = line;
    }
} catch (IOException e) {
    // Don't swallow silently -- make the failure visible.
    e.printStackTrace();
}
2. 传入bitmap
InputStream inputStream = getAssets().open("2.jpg");
croppedBitmap = BitmapFactory.decodeStream(inputStream);
results = recognizeImage(croppedBitmap);
3. bitmap转化为bytebuffer
/**
 * Converts a 280x32 grayscale bitmap into the single-channel float
 * ByteBuffer the TFLite model consumes, normalized to [-0.5, 0.5).
 */
private void convertBitmapToFloat1CBuffer(Bitmap bitmap) {
  if (imgDataFloat == null) {
    // 1 batch * 32 rows * 280 cols * 4 bytes per float.
    imgDataFloat =
        ByteBuffer.allocateDirect(
            1 * 32 * 280 * 4);
    // Required: the interpreter expects native byte order.
    imgDataFloat.order(ByteOrder.nativeOrder());
  }
  // Reset the position so repeated calls reuse the buffer instead of
  // overflowing it (the original only allocated once and never rewound).
  imgDataFloat.rewind();
  bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
  // Pixels are consumed strictly sequentially (pixel++), so the two loop
  // bounds only need to multiply to 32 * 280.
  int pixel = 0;
  for (int i = 0; i < 280; ++i) {
    for (int j = 0; j < 32; ++j) {
      final int val = intValues[pixel++];
      // Source is a 3-channel grayscale image (R == G == B), so any single
      // channel works; the G channel is taken here and normalized.
      imgDataFloat.putFloat((float) (((val >> 8) & 0xFF) / 255.0 - 0.5));
      // For a true RGB image, convert to luma instead:
      // float R = (float) ((val >> 16) & 0xFF);
      // float G = (float) ((val >> 8) & 0xFF);
      // float B = (float) (val & 0xFF);
      // imgDataFloat.putFloat((float) (R * 0.3 + G * 0.59 + B * 0.11));
    }
  }
}
4. run,然后取出输出数组
// Output buffer: batch 1, 35 time steps, 5990 classes per step.
float [][][] labelProb1 = new float[1][35][5990];
// Run inference: reads imgDataFloat and fills labelProb1 in place.
tfLite.run(imgDataFloat, labelProb1);
5. 解码
/**
 * Greedy per-timestep argmax over the class axis of the prediction tensor.
 *
 * @param pred scores shaped [batch][timesteps][classes]; only batch
 *             element 0 is read
 * @return for each timestep, the index of the highest-scoring class, or -1
 *         if every score is <= 0 (should not occur for softmax output, but
 *         decode would then index keys with -1 — TODO confirm upstream)
 */
private int[] argmax(float[][][] pred) {
    int len = pred[0].length;
    int classes = pred[0][0].length;
    int[] argMax = new int[len];
    for (int i = 0; i < len; ++i) {
        float best = (float) 0.0;
        int bestIdx = -1;
        float[] row = pred[0][i];
        for (int j = 0; j < classes; ++j) {
            if (row[j] > best) {
                bestIdx = j;
                best = row[j];
            }
        }
        argMax[i] = bestIdx;
    }
    return argMax;
}
/**
 * CTC-style greedy decode: maps class indices to characters via keys,
 * dropping the blank class (index 5990 - 1) and immediate repeats.
 * NOTE(review): the matchesTwoBack clause re-admits a symbol equal to the
 * one two steps back even when it repeats the previous step — presumably
 * intentional to mirror the training-side decoder; confirm against it.
 */
private String decode(int[] argMax) {
    // StringBuilder avoids the O(n^2) cost of String += in a loop.
    StringBuilder res = new StringBuilder();
    for (int item = 0; item < argMax.length; item++) {
        boolean isBlank = argMax[item] == 5990 - 1;
        boolean repeatsPrev = item > 0 && argMax[item] == argMax[item - 1];
        boolean matchesTwoBack = item > 1 && argMax[item] == argMax[item - 2];
        if (!isBlank && (!repeatsPrev || matchesTwoBack)) {
            res.append(keys.charAt(argMax[item]));
        }
    }
    return res.toString();
}
6. 整体调用流程
// End-to-end recognition flow for one bitmap.
float[][][] labelProb1 = new float[1][35][5990]; // output buffer
int[] argMax;
try {
    convertBitmapToFloat1CBuffer(bitmap); // bitmap -> normalized float buffer
    tfLite.run(imgDataFloat, labelProb1); // run inference
    argMax = argmax(labelProb1);          // per-timestep best class index
    String res = decode(argMax);          // indices -> text
} catch (Exception e) {
    // Log instead of swallowing silently so failures are visible.
    e.printStackTrace();
}