tensorflow 将.pb文件量化操作为.tflite

tf版本:1.12.0 gpu版本

1.pb文件量化为.lite文件

import tensorflow as tf

path_to_frozen_graphdef_pb = '/***/****/frozen_eval_graph-245400-20000.pb' #模型路径
converter = tf.contrib.lite.TFLiteConverter.from_frozen_graph(path_to_frozen_graphdef_pb,
                                                            input_arrays=["Placeholder"],#模型输入节点的name
                                                            output_arrays=["head/reg13x13_output/BiasAdd","head/reg26x26/BiasAdd","head/reg52x52/BiasAdd"])#模型输出节点的name,我的是yolov3,所以有三个输出

converter.inference_type = tf.contrib.lite.constants.QUANTIZED_UINT8  #转成uint8格式
converter.quantized_input_stats = {"Placeholder": (0., 1.)}   #改前边"Placeholder"为自己的name,后边(0., 1.)不用动
converter.allow_custom_ops = True
converter.default_ranges_stats = (0, 255) #不动
converter.post_training_quantize = True
tflite_model = converter.convert()
open("/****/eval_graph.tflite", "wb").write(tflite_model)   #前边为要保存的.tflite路径

2.调用.lite文件

interpreter = tf.contrib.lite.Interpreter(model_path="./best_distillation_float_model_folder/pedestrian111.tflite")  #模型文件地址
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

np_array_crop = np_array_crop/255.0
np_array_crop = np.array(np_array_crop, dtype=np.float32)

start_time = time.time()
interpreter.set_tensor(input_details[0]['index'], [np_array_crop])
interpreter.invoke()

#yolov3行人检测,因此有三个输出。
predictions1 = interpreter.get_tensor(output_details[0]['index'])
predictions2 = interpreter.get_tensor(output_details[1]['index'])
predictions3 = interpreter.get_tensor(output_details[2]['index'])                

pedestrian_res, pedestrian_coord, pedestrian_bboxes, other_coord, \
        other_bboxes, box_offsets, pedestrian_conf = [], [], [], [], [], [], []
for b in range(predictions1.shape[0]):
    boxes = []
    stage_idx = 1
    for pred_input in [predictions1, predictions2, predictions3]:
        boxes_tmp, confs, cls_confs = boxes_in_batch(pred_input[b,:,:,:], 
            net_input_w, net_input_h, num_object, num_class, 
            anchors[(stage_idx - 1) * 2 * num_object : stage_idx * 2 * num_object],
            conf_thresh, class_thresh)
        boxes.extend(boxes_tmp)
        stage_idx = stage_idx +1
    boxes = np.asarray(boxes)
    print("len",len(boxes))
    if boxes.shape[0] > 0:    
        #计算nms非极大抑制值
        nms_box =util.nms(boxes, num_class, nms_thresh)
        print("after nms len",len(nms_box))
        # pedestrian_res = copy.deepcopy(nms_box)
        pedestrian_res = nms_box

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值