Converting a Keras h5 Model to TFLite and Deploying It on Mobile

Format Conversion

1. Save the checkpointed model as a basemodel

import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
# dense_cnn_svd is the author's network-building function, defined elsewhere in the project

nclass = 5990
input = Input(shape=(32, 280, 1), name='the_input')
y_pred = dense_cnn_svd(input, nclass, 92)
basemodel = Model(inputs=input, outputs=y_pred)
basemodel.load_weights("densenet_svd/dense_svd.h5")
basemodel.summary()
# save the full model (architecture + weights) so the converter can load it from the h5 file
tf.keras.models.save_model(basemodel, "densenet_svd/dense_svd_new.h5")

2. Conversion; the lowest version verified to work so far is TensorFlow 1.12.0

converter = tf.contrib.lite.TocoConverter.from_keras_model_file("densenet_svd/dense_svd_new.h5")
# in later versions use tf.lite.TFLiteConverter.from_keras_model_file instead
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)

3. Quantization; just add one flag

converter = tf.contrib.lite.TocoConverter.from_keras_model_file("densenet_svd/dense_svd_new.h5")
converter.post_training_quantize = True  # set the flag for the post-training quantization tool
tflite_model = converter.convert()
open("converted_model_quant.tflite", "wb").write(tflite_model)

You can also convert directly from the session, which is especially handy when the model contains custom layers or custom functions:

# model is the Keras model already loaded in memory; K is the Keras backend (e.g. from keras import backend as K)
sess = K.get_session()
graph = tf.get_default_graph()
print(model.input.name, model.output.name)
input_tensor = graph.get_tensor_by_name(model.input.name)
output_tensor = graph.get_tensor_by_name(model.output.name)

converter = tf.contrib.lite.TocoConverter.from_session(sess, [input_tensor], [output_tensor])
converter.post_training_quantize = True
tflite_model = converter.convert()
open("ssd_mobile.tflite", "wb").write(tflite_model)

4. Running inference in Python

import cv2
import numpy as np
import tensorflow as tf

# Build the interpreter and allocate memory for the tensors
interpreter = tf.contrib.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()

# Get information about the input and output tensors
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Get the shape of the input
input_shape = input_details[0]['shape']
print(input_shape)
# input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)

# Read an image and resize it to the input shape
input_data = cv2.imread("images/2.jpg", cv2.IMREAD_GRAYSCALE)
input_data = input_data.astype("float32") / 255.0 - 0.5
input_data = cv2.resize(input_data, (280, 32))
input_data = input_data.reshape([1] + list(input_data.shape) + [1])

# Feed the input tensor
print(input_data.shape)
interpreter.set_tensor(input_details[0]['index'], input_data)

# Run inference
interpreter.invoke()

# Fetch the inference result
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
print(decode(output_data))
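
The decode function called above is not defined in this post; a minimal sketch of the greedy, CTC-style decoding (assuming keys is a string of the 5990 characters loaded from key.txt, with the last index treated as the blank) could look like this:

def decode(pred):
    # pred has shape (1, time_steps, nclass); take the argmax over classes at each time step
    char_ids = np.argmax(pred[0], axis=1)
    nclass = pred.shape[-1]
    text = ""
    for i in range(len(char_ids)):
        # skip the blank class and collapse immediate repeats, mirroring the Java decode further below
        if char_ids[i] != nclass - 1 and (not (i > 0 and char_ids[i] == char_ids[i - 1])
                                          or (i > 1 and char_ids[i] == char_ids[i - 2])):
            text += keys[char_ids[i]]  # keys: assumed character list loaded from key.txt
    return text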

Android Deployment

Deployment on the Android side breaks down into the following steps:

  1. Load the model file, key file, etc.
  2. Pass in a bitmap
  3. Convert the bitmap to a ByteBuffer
  4. Run, then take out the output array
  5. Decode

1. Load the model file, key file, etc.

private static MappedByteBuffer loadModelFile(AssetManager assets, String modelFilename)
    throws IOException {
    // load the model file from assets as a memory-mapped buffer
    AssetFileDescriptor fileDescriptor = assets.openFd(modelFilename);
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    long startOffset = fileDescriptor.getStartOffset();
    long declaredLength = fileDescriptor.getDeclaredLength();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
import org.tensorflow.lite.Interpreter;
...
try { // load the model and instantiate the Interpreter
  Interpreter tfLite = new Interpreter(loadModelFile(assetManager, modelFilename));
} catch (Exception e) {
  throw new RuntimeException(e);
}

try {
    // read the keys (character list)
    BufferedReader reader = new BufferedReader(new InputStreamReader(assetManager.open("key.txt")));
    String line;
    while ((line = reader.readLine()) != null) {
        keys = line;
    }
} catch (IOException e) {
    // ignored: keys stays unset if the file cannot be read
}

2. Pass in a bitmap

InputStream inputStream = getAssets().open("2.jpg");
croppedBitmap = BitmapFactory.decodeStream(inputStream);
results = recognizeImage(croppedBitmap);

3. Convert the bitmap to a ByteBuffer

private void convertBitmapToFloat1CBuffer(Bitmap bitmap) {
    if (imgDataFloat == null) {
        // allocate memory for the input: 1 * 32 * 280 floats, 4 bytes each
        imgDataFloat =
            ByteBuffer.allocateDirect(
            1 * 32 * 280 * 4);
        imgDataFloat.order(ByteOrder.nativeOrder()); // the official docs say this call is required
    }
    bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
    // Convert the image to floating point.
    int pixel = 0;
    for (int i = 0; i < 280; ++i) {
        for (int j = 0; j < 32; ++j) {
            final int val = intValues[pixel++];
            imgDataFloat.putFloat((float) (((val >> 8) & 0xFF) / 255.0 - 0.5));
            // the decoded image here is a 3-channel grayscale image whose channels are identical,
            // so any single channel works; for a true 3-channel color image you may need to convert:
            // float R = (float) ((val >> 16) & 0xFF);
            // float G = (float) ((val >> 8) & 0xFF);
            // float B = (float) (val & 0xFF);
            // imgDataFloat.putFloat((float) (R * 0.3 + G * 0.59 + B * 0.11));
        }
    }
}

4. Run, then take out the output array

// output buffer: 1 batch x 35 time steps x 5990 classes
float[][][] labelProb1 = new float[1][35][5990];
tfLite.run(imgDataFloat, labelProb1);

5. Decode

// per-time-step argmax over the class dimension
private int[] argmax(float[][][] pred) {
    int len = pred[0].length;
    int classes = pred[0][0].length;
    int[] argMax = new int[len];

    for (int i = 0; i < len; ++i) {
        float max = 0.0f;
        int maxj = -1;
        float[] arr = pred[0][i];
        for (int j = 0; j < classes; ++j) {
            if (arr[j] > max) {
                maxj = j;
                max = arr[j];
            }
        }
        argMax[i] = maxj;
    }

    return argMax;
}

// CTC-style greedy decode: skip the blank class (index 5990 - 1) and collapse immediate repeats
private String decode(int[] argMax) {
    String res = "";
    for (int item = 0; item < argMax.length; item++) {
        if (argMax[item] != 5990 - 1 && (!(item > 0 && argMax[item] == argMax[item - 1]) || (item > 1 && argMax[item] == argMax[item - 2]))) {
            res += keys.charAt(argMax[item]);
        }
    }
    return res;
}

6. Overall calling flow

float[][][] labelProb1 = new float[1][35][5990]; // allocate the output array
int[] argMax;
try {
    convertBitmapToFloat1CBuffer(bitmap); // convert the bitmap to the input buffer
    tfLite.run(imgDataFloat, labelProb1); // run inference
    argMax = argmax(labelProb1);          // argmax per time step
    String res = decode(argMax);          // decode to a string
} catch (Exception e) {

}