YOLOv4部署-keras权重转tflite (h5 weights -> pb -> tflite)

该博客详细介绍了如何在TensorFlow 2.2.0环境下,将预训练的Keras h5模型转换为.pb文件,然后进一步转化为.tflite轻量化模型。首先,通过修改Keras的导入方式以适配TensorFlow 2.x,接着加载h5模型并保存为.pb。最后,使用TFLiteConverter将.pb模型转换为.tflite,以便于在移动端或其他资源有限的环境中部署。
摘要由CSDN通过智能技术生成

一、相关环境介绍

1、网络训练的环境

  • tensorflow-gpu==1.4.0
  • Keras==2.1.5
  • python==3.6.2

2、h5->pb->tflite的环境

  • tensorflow==2.2.0
  • python==3.6.2

二、转换过程

1、加载h5权重,并转换为pb

  • 修改yolov4网络结构中keras的导包命令,tensorflow=2.2.0中自带的有keras

    # 将所有的
    from keras import *
    from keras.* import *
    # 修改过为:
    from tensorflow.keras import *
    from tensorflow.keras.* import *
    
  • 加载权重和网络结构

    from tensorflow.keras.models import Model
    import tensorflow as tf
    
    class YOLO(object):
        _defaults = {
            ...
        }
    
        @classmethod
        def get_defaults(cls, n):
            if n in cls._defaults:
                return cls._defaults[n]
            else:
                return "Unrecognized attribute name '" + n + "'"
    
        #---------------------------------------------------#
        #   初始化yolo
        #---------------------------------------------------#
        def __init__(self, **kwargs):
            self.__dict__.update(self._defaults)
            for name, value in kwargs.items():
                setattr(self, name, value)
                self._defaults[name] = value 
            self.class_names, self.num_classes = get_classes(self.classes_path)
            self.anchors, self.num_anchors     = get_anchors(self.anchors_path)
            self.generate()
    
        #---------------------------------------------------#
        #   载入模型
        #---------------------------------------------------#
        def generate(self):
            model_path = os.path.expanduser(self.model_path)
            assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'
            self.model = yolo_body([640, 640, 3], ...)
            self.model.load_weights(self.model_path)
            outputs = Lambda(
                DecodeBox, 
                output_shape = (1,), 
                name = 'yolo_eval',
                arguments = {...}
            )(self.model.output)  # self.model.output输出是三个特征层
            self.yolo_model = Model(self.model.input, outputs)
            tf.saved_model.save(self.yolo_model, "yolo_tflite/yolov4")
    

2、pb文件转换为tflite

 converter = tf.lite.TFLiteConverter.from_saved_model('yolo_tflite/yolov4')
 converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
 converter.allow_custom_ops = True
 tflite_model = converter.convert()
 open('yolo_tflite/yolo_fp32.tflite', 'wb').write(tflite_model)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值