TFGraphCliper——TF Graph 的修改方法

TFGraphCliper——TF Graph 的修改方法

在移动平台上,我们做 AI 网络的移植落地,总会遇到各式各样的问题。这些问题,大多跟现在深度学习移动框架的不兼容、不完善有关。最近,某家新发布的移动框架对 keras 的 Reshape 算子( Tensorflow backend ) 不兼容,使得我再次碰到这种窘境,需要对一个已经训练好的 frozen graph (也就是 .pb 文件) 进行针对性修改。尽管 Tensorflow 相对完善的框架体系,已经提供了 graph_transformer 工具,但也不好覆盖到这种生僻的问题。因此,我决定这次总结一个相对万金油的方法,并形成一个简单的 demo 工具库,尽可能解决这类问题。

这是我修改工具的链接:SunAriesCN/TFGraphCliper

由于时间原因,我只写了英文说明,有中文需求的朋友欢迎留言。
下面是我直接贴原文:

First, we should have a look at my ‘test.pb’ graph by tensorboard

Generate the event directory of ‘test.pb’:

$ python tools/import_pb_to_tensorboard.py --model_dir='examples/test.pb' --log_dir='log_test'

Visualize by running:

$ tensorboard --logdir=log_test

这里写图片描述

Assume that we want to concatenate our Placeholder input and Relu6, so we need to add a concat op into ‘test.pb’. However, the ‘test.pb’ file is a binary format, it’s difficult for human to read,to say nothing to change. For that, I think we can convert it into ‘text.pbtxt’ the easier format for us to know what it is, and edit it.

import tensorflow as tf
graph_path = '../examples/test.pb'
with tf.Graph().as_default() as g_1:
    g1_def = tf.GraphDef()
    with open(graph_path, 'rb') as f:
        g1_def.ParseFromString(f.read())
        _ = tf.import_graph_def(g1_def, name="")
        sess = tf.Session()
        tf.train.write_graph(sess.graph, '../examples','test.pbtxt')

Run the codes above, and read our test.pbtxt

$ vim examples/test.pbtxt

We can see somethings as following:

node {
  name: "Placeholder"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }   
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: -1
        }   
        dim {
          size: 256 
        }   
        dim {
          size: 256 
        }   
        dim {
          size: 3
        }   
      }   
    }   
  }
}
node {
  name: "Const"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }   
  }
  attr {
    key: "value"

......

And then, we also want to know what is the ‘concat’ op look like in the ‘.pbtxt’ file, so we can make a small demo for it as following:

with tf.Graph().as_default() as g2:
    x = tf.placeholder(dtype=tf.float32, shape=[None, 256, 256, 3])
    # the relu6 output channel size is 13, you can get this info from tensorboard.
    y = tf.placeholder(dtype=tf.float32, shape=[None, 256, 256, 13])
    z = tf.concat([x,y],axis=-1)
    sess1 = tf.Session()
    tf.train.write_graph(sess1.graph, '../examples','demo.pbtxt')

Open the ‘demo.pbtxt’ you may see somethings as following:

node {
  name: "concat/axis"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }   
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }   
        int_val: -1
      }   
    }   
  }
}
node {
  name: "concat"
  op: "ConcatV2"
  input: "Placeholder"
  input: "Placeholder_1"
  input: "concat/axis"
  attr {
    key: "N" 
    value {
      i: 2
    }   
  }
  attr {
    key: "T" 
    value {
      type: DT_FLOAT
    }   
  }
  attr {

.........

So, what can we do next? I think it’s quite clear for everyone, we just copy the code about ‘concat’ op into ‘test.pbtxt’ and modify the inputs info, then we will get our target graph file. I give you some tips of code in ‘test.pbtxt’ as following:

......

node {
  name: "concat"
  op: "ConcatV2"
  input: "Relu6"
  input: "Placeholder"
  input: "concat/axis"
  attr {
    key: "N" 
    value {
      i: 2
    }   
  }
  attr {
    key: "T" 
    value {
      type: DT_FLOAT
    }   
  }
  attr {
    key: "Tidx"
    value {
      type: DT_INT32
    }   
  }
}

......

At last, we need to convert our ‘test.pbtxt’ back into ‘test.pb’ file, and to tell the cliped one, the new ‘test.pb’ I will rename it as ‘test_cliped.pb’

from google.protobuf import text_format
graph_path = '../examples/test.pbtxt'
with tf.Graph().as_default() as g_1:
    g1_def = tf.GraphDef()
    with open(graph_path, 'rb') as f:
        text_format.Merge(f.read(), g1_def)
        _ = tf.import_graph_def(g1_def, name="")
        sess = tf.Session()
        tf.train.write_graph(sess.graph, '../examples','test_cliped.pb', as_text=False)

Using tools to check our modification.

$ python tools/import_pb_to_tensorboard.py --model_dir='examples/test_cliped.pb' --log_dir='log_test_cliped'
$ tensorboard --logdir=log_test_cliped

这里写图片描述

Bravo!

But it may not really mean a happy ending before we check the cliped graph work well as our expectation. Actually, we may meet lots of probelms about the wrong result from our cliped graph, and we should keep checking and refining our graph for our expectaton.

FIN.

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值