Tensorflow 静态图PB模型修改(OP修改)

def load_pb_graph(path):
    with tf.gfile.GFile(path, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as g:
        tf.import_graph_def(graph_def, name=None)
    return g


model_filename = '111.pb'
g = load_pb_graph(model_filename)
#加载原图完毕

new_model = tf.GraphDef()
with tf.Session(graph=g) as sess:
    for n in sess.graph_def.node:
        if n.name in ['import/input_ids','import/input_mask', 'import/token_type_ids']:
            nn = new_model.node.add()
            nn.op = n.op
            nn.name = n.name
            nn.attr['dtype'].CopyFrom(tf.AttrValue(type=tf.int32.as_datatype_enum))
            s = tensor_shape_pb2.TensorShapeProto()
            d1 = tensor_shape_pb2.TensorShapeProto.Dim()
            d2 = tensor_shape_pb2.TensorShapeProto.Dim()
            d1.size = 1
            d2.size = 7
            s.dim.extend([d1,d2])
            nn.attr['shape'].shape.CopyFrom(s)
            for i in n.input:
                nn.input.extend([i])
        else:
            new_model.node.append(n)
            # nn = new_model.node.add()
            # nn.CopyFrom(n) 太过于耗时,可以使用append直接加入old节点

print('*'*100)
#将新图注入到默认的Graph中
#tf.import_graph_def(new_model, name='')  # Imports `graph_def` into the current default `Graph`

# 测试案例
with tf.Session() as sess:
    tf.train.write_graph(new_model, logdir='./', name='graph_def_new.pb',  as_text=False)

pb输入会有import,ckpt就是node的名字。
这里将输入的128维度替换成7维,主要是为了测时间,如果是标准bert,后续还要修改reshape,这里对后续的所有操作融合成了一个OP,所以仅仅需要修改输入的dim即可。

最后保存pb时候, tf.train.write_graph,特别注意可能特别慢,因为需要把as_text设置成False,否则图被写成一个 text proto。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值