Tensorflow替换静态图中的OP

import tensorflow as tf
import collections
from tensorflow.core.framework import tensor_shape_pb2

# 读取模型
graph_def = tf.GraphDef()
with tf.gfile.FastGFile('./pb/model.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())


# 统计图中的node,保存为map.其中 key : value = op.name : op
input_node_map = {}
for node in graph_def.node:
    if node.name not in input_node_map.keys():
        input_node_map[node.name] = node
    else:
        raise ValueError("Duplicate node names detected for ", node.name)

# 统计每一个op被使用的次数
node_reference_count = collections.defaultdict(int)
output_node_names = ['xnet/Softmax']
for node in graph_def.node:
    for input_name in node.input:
        stripped_name = input_name
        node_reference_count[stripped_name] += 1
for output_name in output_node_names:
    node_reference_count[output_name] += 1

# 删除old_op
old_op = input_node_map['xnet/Layer_Conv_1/Conv2D']
node_reference_count['xnet/Layer_Conv_1/Conv2D'] -= 1

# 创建新的op
new_node = tf.NodeDef()
new_node.op = 'Conv2D'
new_node.name = 'new_Conv_1'
for input_name in old_op.input:
    new_node.input.extend([input_name])
new_node.attr["T"].CopyFrom(tf.AttrValue(type=tf.float32.as_datatype_enum))  # (old_op.attr["T"])
new_node.attr["use_cudnn_on_gpu"].CopyFrom(tf.AttrValue(b=1))  # (old_op.attr["use_cudnn_on_gpu"])
stride_list = [1, 2, 2, 1]
new_node.attr["strides"].CopyFrom(tf.AttrValue(list=tf.AttrValue.ListValue(i=stride_list)))  # (old_op.attr["strides"])
new_node.attr["padding"].CopyFrom(tf.AttrValue(s=b'VALID'))  # (old_op.attr["padding"])

# 创建const类型的op,仅作为测试,本实验中不添加入graph
new_const = tf.NodeDef()
new_const.op = 'Const'
new_const.name = 'new_Const'
new_const.attr['dtype'].CopyFrom(tf.AttrValue(type=tf.float32.as_datatype_enum))
new_const.attr['value'].CopyFrom(
    tf.AttrValue(tensor=tf.make_tensor_proto([4, 5, 0, 0, 8, 0, 7, 0], tf.float32, [4, 2])))
new_const.attr['_output_shapes'].CopyFrom(
    tf.AttrValue(list=tf.AttrValue.ListValue(shape=[tensor_shape_pb2.TensorShapeProto(
        dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=4), tensor_shape_pb2.TensorShapeProto.Dim(size=2)])])))

# 将new_node作为输入赋值给图中节点
for node in graph_def.node:
    if old_op.name in node.input:
        for i, name in enumerate(node.input):
            if name == old_op.name:
                node.input[i] = new_node.name
print('success_1')

# 定义一个新图
graph_def_new = tf.GraphDef()
for node in graph_def.node:
    if node_reference_count[node.name] < 1:
        continue
 new = tf.NodeDef()
    new.CopyFrom(node)
    graph_def_new.node.extend([new])
graph_def_new.node.extend([new_node])
# graph_def_new.node.extend([new_const])

# 将新图注入到默认的Graph中
tf.import_graph_def(graph_def_new, name='')  # Imports `graph_def` into the current default `Graph`

# 测试案例
with tf.Session() as sess:
    tf.train.write_graph(sess.graph_def, logdir='./pb', name='graph_def_new.pb')

OP的信息:

name: "xnet/Layer_FC_32/xw_plus_b"

op: "BiasAdd"

input: "xnet/Layer_FC_32/xw_plus_b/MatMul"

input: "xnet/Layer_FC_32/biases/read"

attr

{

key: "T"

value {type: DT_FLOAT}

}

attr

{

key: "data_format"

value {s: "NHWC"}

}

在tensorflow中,OP主要包括以下信息:name, op , input, attr

  1. name--类型string。 在模型定义的时候由工程师定义,如果工程师没有定义的话会自动的利用op作为其值

  2. op--类型string。表示这是一个什么op,比如加减乘除,当在运行的时候,编译器会更具op调用相应的算子来做计算

  3. input--类型list.。列表中包含了该节点输入,是有序的,不可以被assign

  4. attr--类型map。map中的key和value一般是指该OP的配置信息

OP的操作:

1、op信息获取

1. 通过Graph获取op

op = tf.get_default_graph().get_Operations()

print(op[0])

print(op[0].name)

# 如果想获得属性或者input信息需要如下写法

print(op[0].node_def.attr)

2.通过Graph_def获取op

op = graph_def.node

print(op[0].name)

print(op[0].input)

2、op的创建

在构建新的op的时候需要对op的属性比较清楚,对于没有default的属性一定要做好初始化

1.根据已有的op创建新的op

 

new_node = tf.NodeDef()  # 构建一个op对象,所有属性都为空

new_node.op = 'Conv2D'

new_node.name = 'new_Conv_1'

for input_name in old_op.input:  # 原始op的input导入进来

       new_node.input.extend([input_name])

new_node.attr["T"].CopyFrom(old_op.attr["T"])

new_node.attr["use_cudnn_on_gpu"].CopyFrom(old_op.attr["use_cudnn_on_gpu"])

new_node.attr["strides"].CopyFrom(old_op.attr["strides"])

new_node.attr["padding"].CopyFrom(old_op.attr["padding"])

 

2.创建一个自定义的op

new_op = tf.NodeDef()

new_op.op = "Const"

new_op.name = conv_op.name

new_op.attr["dtype"].CopyFrom(tf.AttrValue( type=tf.int32.as_datatype_enum))

new_op.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto([0, 0, 0, 0, 0, 0, 0, 0],                                                                                                tf.int32, [4, 2])))

 

OP中attr为map,每一个map中key为字符串,value为的类型由下面9种,每种对应的原型如下表所示:

repeated bytes s = 2; // "list(string)"

repeated int64 i = 3 [packed = true]; // "list(int)"

repeated float f = 4 [packed = true]; // "list(float)"

repeated bool b = 5 [packed = true]; // "list(bool)"

repeated DataType type = 6 [packed = true]; // "list(type)"

repeated TensorShapeProto shape = 7; // "list(shape)"

repeated TensorProto tensor = 8; // "list(tensor)"

repeated NameAttrList func = 9; // "list(attr)"

list 也为value的一种类型

 

每一种类型初始化方式:

CopyFrom(tf.AttrValue( s=b'hello,world'))

CopyFrom(tf.AttrValue( i=88 ))

CopyFrom(tf.AttrValue( f=88.0 ))

CopyFrom(tf.AttrValue( b=1/0 ))

new_op.attr["dtype"].CopyFrom(tf.AttrValue( type=tf.int32.as_datatype_enum))

from tensorflow.core.framework import tensor_shape_pb2

tensor_shape_pb2.TensorShapeProto(dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=-1 if d.value is None else d.value) for d in dims])

new_op.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto([0, 0, 0, 0, 0, 0, 0, 0],                                                                                                tf.int32, [4, 2])))

func目前没有没有遇到过

stride_list = [1, 2, 2, 1]

new_node.attr["strides"].CopyFrom(tf.AttrValue(list=tf.AttrValue.ListValue(i=stride_list)))

  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值