TensorFlow的建图向来更支持读和添加操作,而不是删除和更新,特别是对于已经经过freeze_graph 的pb (protobuffer)模型。
然而,总可能会遇到情况需要对pb图做修改。TF提供了graph_editor 和graph transform tool,然而,这并不能解决所有问题。
比如,我现在有一个pretrained的pb model,但模型把两个头结点定义成了Variable 而不是placeholder, 导致没法利用TFLite做转化。下面,我便尝试直接在pb图上修改。
要修改的节点如下
name: "previous_state_h"
op: "VariableV2"
}
方法一
人工修改。先把pb读入,再导出成pbtxt格式(人可读),然后手工修改对应节点,最后再读入导出为pb。这种适合为模型不太大且参数不太多的情况,要不一般编辑器吃不消。
方法二
读入后直接修改网络结构。pb模型是有Graph_def定义,其中包含多个node,每个node存在op(操作),attr(属性)等值。
首先,先读入pb模型
graph_def = tf.GraphDef()
with open(model_filename, 'rb') as f:
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def, name='')
接着,构建新的图
new_model = tf.GraphDef()
with tf.Session(graph=graph) as sess:
for n in sess.graph_def.node:
if n.name in ['previous_state_c','previous_state_h']:
nn = new_model.node.add()
nn.op = 'Placeholder'
nn.name = n.name
nn.attr['dtype'].type = 1
s = tensor_shape_pb2.TensorShapeProto()
d1 = tensor_shape_pb2.TensorShapeProto.Dim()
d2 = tensor_shape_pb2.TensorShapeProto.Dim()
d1.size = 1
d2.size = 16
s.dim._values = [d1,d2]
nn.attr['shape'].shape.CopyFrom(s)
for i in n.input:
nn.input.extend([i])
else:
nn = new_model.node.add()
nn.CopyFrom(n)
其中注意,对简单的值可以直接赋值,比如nn.name 和nn.op。但是,对于复杂的对象,比如shape,则需要先打包好一个shape的实例,再通过CopyFrom赋值,直接赋值是不允许的。
最终,结果,可以看到previous_state_h节点已经修改了。
name: "previous_state_h"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
dim {
size: 16
}
}
}
资料:
Replacing a node in a frozen Tensorflow modelstackoverflow.com