python解析tflite模型文件

目的

解析flatbuffer格式的tflite文件,转成可读的python dict格式,并可描述模型完整推理流程。

背景

tf.lite.Interpreter可以读tflite模型,但是其python接口没有描述模型结构(op node节点间的连接关系)

比如,interpreter.get_tensor_details()获取的信息,如下

[{'name': 'input_13',
  'index': 0,
  'shape': array([  1,   1, 240,   1], dtype=int32),
  'shape_signature': array([ -1,   1, 240,   1], dtype=int32),
  'dtype': numpy.float32,
  'quantization': (0.0, 0),
  'quantization_parameters': {'scales': array([], dtype=float32),
   'zero_points': array([], dtype=int32),
   'quantized_dimension': 0},
  'sparsity_parameters': {}},
 {'name': 'model_12/dense_144/BiasAdd/ReadVariableOp/resource',
  'index': 1,
  'shape': array([2], dtype=int32),
  'shape_signature': array([2], dtype=int32),
  'dtype': numpy.float32,
  'quantization': (0.0, 0),
  'quantization_parameters': {'scales': array([], dtype=float32),
   'zero_points': array([], dtype=int32),
   'quantized_dimension': 0},
  'sparsity_parameters': {}},
 {'name': 'model_12/dense_145/BiasAdd/ReadVariableOp/resource',
....

按本文方式,可以直接获取节点的op参数、输入、输出序号,如下

subg
{'inputs': [0],
 'name': [109, 97, 105, 110],
 'operators': [{'builtin_options': {'dilation_h_factor': 1,
    'dilation_w_factor': 1,
    'fused_activation_function': 1,
    'padding': 0,
    'stride_h': 1,
    'stride_w': 1},
   'builtin_options_type': 1,
   'custom_options': None,
   'custom_options_format': 0,
   'inputs': [0, 52, 68],   //输入节点
   'intermediates': None,
   'mutating_variable_inputs': None,
   'opcode_index': 0,
   'outputs': [82]},        //输出节点
  {'builtin_options': {'depth_multiplier': 1,
...

tensors
[{'buffer': 1,
  'is_variable': False,
  'name': [105, 110, 112, 117, 116, 95, 49, 51],
  'quantization': {'details': None,
   'details_type': 0,
   'max': None,
   'min': None,
   'quantized_dimension': 0,
   'scale': None,
   'zero_point': None},
  'shape': [1, 1, 240, 1],
  'shape_signature': [-1, 1, 240, 1],
  'sparsity': None,
  'type': 0},
 {'buffer': 2,
...

方法

#/tensorflow/lite/tools/visualize.py
import re

import numpy as np
from tensorflow.lite.python import schema_py_generated as schema_fb
 
def BuiltinCodeToName(code):
    """Return the symbolic name for a BuiltinOperator enum value, or None."""
    # The BuiltinOperator class is a flat namespace of NAME -> int constants;
    # scan it for the first entry whose value matches.
    matches = (name for name, value in schema_fb.BuiltinOperator.__dict__.items()
               if value == code)
    return next(matches, None)
def CamelCaseToSnakeCase(camel_case_input):
    """Convert a CamelCase identifier to snake_case."""
    # Pass 1: put an underscore before each "Upper+lower" word that follows
    # any character; pass 2: split remaining lower/digit-to-upper boundaries.
    word_breaks = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_input)
    fully_split = re.sub("([a-z0-9])([A-Z])", r"\1_\2", word_breaks)
    return fully_split.lower()
def FlatbufferToDict(fb, preserve_as_numpy):
    """Recursively convert a flatbuffer object-API instance to plain Python data.

    Scalars pass through unchanged; objects with attributes become dicts keyed
    by snake_case attribute names; numpy arrays become lists unless preserved;
    other sequences become lists.

    Args:
        fb: a flatbuffer object (e.g. schema_fb.ModelT) or any nested value.
        preserve_as_numpy: if True, keep numpy arrays as-is instead of
            calling tolist(). The "buffers" attribute (raw weight bytes) is
            always preserved as numpy, since converting large weight blobs to
            Python lists is slow and memory-hungry.

    Returns:
        A nested structure of dicts, lists and scalars (plus numpy arrays
        where preservation applies).
    """
    if isinstance(fb, (int, float, str)):
        return fb
    if hasattr(fb, "__dict__"):
        result = {}
        for attribute_name in dir(fb):
            attribute = fb.__getattribute__(attribute_name)
            # Skip methods and private/dunder attributes.
            if not callable(attribute) and attribute_name[0] != "_":
                snake_name = CamelCaseToSnakeCase(attribute_name)
                # Raw buffers stay numpy regardless of the caller's preference.
                preserve = True if attribute_name == "buffers" else preserve_as_numpy
                result[snake_name] = FlatbufferToDict(attribute, preserve)
        return result
    if isinstance(fb, np.ndarray):
        return fb if preserve_as_numpy else fb.tolist()
    if hasattr(fb, "__len__"):
        return [FlatbufferToDict(entry, preserve_as_numpy) for entry in fb]
    return fb
def CreateDictFromFlatbuffer(buffer_data):
    """Parse raw .tflite bytes into a nested plain-Python dict."""
    # Root the flatbuffer, lift it into the generated object API, then
    # walk it into dicts/lists (weight buffers excepted, which stay numpy).
    root = schema_fb.Model.GetRootAsModel(buffer_data, 0)
    model_object = schema_fb.ModelT.InitFromObj(root)
    return FlatbufferToDict(model_object, preserve_as_numpy=False)

转换

# Read the model bytes.
with open('xxx.tflite', 'rb') as f:
    model_buffer = f.read()

# Keep an interpreter over the same bytes: once a tensor index is known,
# interpreter.get_tensor(idx) returns its concrete values.
interpreter = tf.lite.Interpreter(model_content=model_buffer)
interpreter.allocate_tensors()


data = CreateDictFromFlatbuffer(model_buffer)
op_codes = data['operator_codes']   # operator codes registered by this model
subg = data['subgraphs'][0]         # graph description: the ordered op nodes
tensors = subg['tensors']           # per-tensor metadata: shapes, buffer ids


for layer in subg['operators']:
    # Readable layer name, resolved through the opcode table.
    layer_name = BuiltinCodeToName(op_codes[layer['opcode_index']]['builtin_code'])

    # Layer hyper-parameters (stride, padding, activation, ...).
    layer_param = layer['builtin_options']

    # Tensor indices wired into / out of this node.
    input_tensor_idx = layer['inputs']
    output_tensor_idx = layer['outputs']

    # Activation input.
    input_idx = input_tensor_idx[0]

    # Filter weights -- assumes a conv-style op whose inputs are
    # [data, weight, bias]; other op types may have fewer inputs.
    weight_idx = input_tensor_idx[1]
    weight = interpreter.get_tensor(weight_idx)  # concrete weight values
    filters = tensors[weight_idx]['shape'][0]    # number of filters (output channels)

    # Filter bias.
    bias_idx = input_tensor_idx[2]

关于Tensor

上述方法在取tensor数值时,用了interpreter.get_tensor(idx)的方式。 实际上tensor数值也可以从data['buffers']中获取,只不过data['buffers']将tensor解析成uint8_t了。

tensors[52]
{'buffer': 53,
 'is_variable': False,
 'name': [109,
...
}

interpreter.get_tensor(52) //32位浮点
array([[[[ 0.89609855],
         [-0.76255393],
         [ 0.2671022 ]]],
...

data['buffers'][53]['data']  //8位
array([183, 102, 101,  63, 188,  54,  ...

可以自己验证一个数试试:取 data['buffers'] 中连续的 4 个 uint8 字节,按小端序重新解释为 float32,结果应与 interpreter.get_tensor 返回的数值一致

uint8_t a[] = {183, 102,101,63};
printf("%f\n", *(float*)a); // = 0.896099

1.tensor_idx : 每个operator里标注的input、output索引(同netron里显示的location) 

                        buffer = interpreter.get_tensor(tensor_idx) //得到数值

                        tensors = subg['tensors'] 

                        tensor = tensors[tensor_idx] //得到一个Tensor对象

2.buffer_idx:Tensor对象中表述数据位置的索引

                        buffer_idx = tensor['buffer']

                        buffer = data['buffers'][buffer_idx] //得到数值

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值