tensorflow-pb导出与导入

4 篇文章 0 订阅
2 篇文章 0 订阅

网络保存

  1. pb 文件:pb文件的网络结构和权重都是固化的不可更改
  2. meta等文件: 文件读入后可以再重新训练。

1. pb文件固化-python

为了实现用python训练网络,用C++直接运行网络,因此需要将网络固化成pb文件输出。
我的pb固化是在前向测试的时候实现的。代码如下

  with tf.name_scope("Input"): #
      input_images = tf.placeholder(dtype=tf.float32, shape=[None, resize_height, resize_width, depths],
                                    name='input_image')  # 由于图像存储的原因,灰度图维度较少一维
      # is_training = tf.placeholder(tf.bool, name='is_training')
  with tf.name_scope("net"):
      out = net.inference(inputs=input_images, num_classes=labels_nums, dropout_keep_prob=1.0,
                          is_training=False)

  # 将输出结果进行softmax分布,再求最大概率所属类别
  with tf.name_scope('Output'):
      score = tf.nn.softmax(out,name='predict')
      class_id = tf.argmax(score, 1)

  sess = tf.InteractiveSession()
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver()
  saver.restore(sess, models_path)
  images_list=glob.glob(os.path.join(image_dir,'*.bmp'))
  for i in range(len(images_list)):
      processed_img = read_image(images_list[i], resize_height, resize_width, normalization=True)
      input_imge = sess.run(processed_img)
      input_imge = input_imge[np.newaxis, :]
      pre_score,pre_label = sess.run([score,class_id], feed_dict={input_images:input_imge})
      max_score=pre_score[0,pre_label]
      file_logger.info("![]({0}){1} pre labels:{2} score: {3}{4}".format(images_list[i],"  \r",labels[pre_label], max_score,"  \r"))
  if is_savePb:
      ##TODO
      ##存成pbfile, 注意必须为List才能存
      outPut_nodeName = []
      outPut_nodeName.append('Output/predict')   #网络输出的节点
      output_graph_def = tf.graph_util.convert_variables_to_constants(
          sess,  # The session is used to retrieve the weights
          tf.get_default_graph().as_graph_def(),  # The graph_def is used to retrieve the nodes
          output_node_names=outPut_nodeName  # The output node names are used to select the usefull nodes
      )

      with tf.gfile.GFile(output_graph, "wb") as f:
          f.write(output_graph_def.SerializeToString())
  sess.close()

需要注意的点是:

  • input 和output要定义好,且input要用placeholder
  • output节点的节点名前面要带上图名,而且outPut_nodeName的类型必须为List,即使只有一个节点
  • 要记住输入输出的节点名,恢复的时候也是用同样的名字

2. pb文件恢复-python版本

代码如下:

def Predict(image_path, PbFileName, resize_height,resize_width):
    """
    测试使用pbFile,单次读取一张图片并输出结果,
    注意要给定输入节点的名称,用sess.graph.get_tensor_by_name("input:0"),节点格式为"nodeName:index"
    :param image_path: 图像路径
    :param PbFileName: pb图像路径
    :param resize_height: 图像高度
    :param resize_width: 图像宽度
    :return:
    """
    #读取pbFile
    labels_filename = 'dataset/label.txt'
    output_graph_def = tf.GraphDef()
    labels = np.loadtxt(labels_filename, str, delimiter='\t')
    with open(PbFileName, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        tf.import_graph_def(output_graph_def, name="")
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)
        #TODO REDEFINE THE INPUT TENSOR IF NESSASARY
        # 定义输入的张量名称,对应网络结构的输入张量
        # input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
        input_image_tensor = sess.graph.get_tensor_by_name("Input/input_image:0")
        output_tensor_name = sess.graph.get_tensor_by_name("Output/predict:0")
        input_imge = read_image(image_path, resize_height, resize_width, normalization=True)
        input_imge = sess.run(input_imge)
        input_imge = input_imge[np.newaxis, :]

        score = sess.run(output_tensor_name, feed_dict={input_image_tensor: input_imge
                                                      })

        print("score:{}".format(score))
        class_id = tf.argmax(score, 1)
        print( "pre class is :{}".format(labels[sess.run(class_id)]))

需要注意的点是

  • 使用sess.graph.get_tensor_by_name(“Input/input_image:0”)的时候除了节点名还要加上":0",使用C++时候不用

3. pb文件导入-C++版本

//图的导入
Status LoadGraph(const string& graph_file_name,
                 std::unique_ptr<tensorflow::Session>* session) {
  tensorflow::GraphDef graph_def;
  Status load_graph_status =
      ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
  if (!load_graph_status.ok()) {
    return tensorflow::errors::NotFound("Failed to load compute graph at '",
                                        graph_file_name, "'");
  }
  session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
  Status session_create_status = (*session)->Create(graph_def);
  if (!session_create_status.ok()) {
    return session_create_status;
  }
  return Status::OK();
}
 //定义session
	std::unique_ptr<tensorflow::Session> session;
    string graph_path = tensorflow::io::JoinPath(root_dir, graph);
    Status load_graph_status = LoadGraph(graph_path, &session);    //输入图
    if (!load_graph_status.ok()) {
      LOG(ERROR) << load_graph_status;
      return -1;
    }
  //定义输入输出
     std::vector<Tensor> resized_tensors;
    string image_path = tensorflow::io::JoinPath(root_dir, image);
    Status read_tensor_status =
        ReadTensorFromImageFile(image_path, input_height, input_width, input_mean,
                                input_std, &resized_tensors);
    if (!read_tensor_status.ok()) {
      LOG(ERROR) << read_tensor_status;
      return -1;
    }
    const Tensor& resized_tensor = resized_tensors[0];

    // Actually run the image through the model.
    std::vector<Tensor> outputs;
    Status run_status = session->Run({{input_layer, resized_tensor}},
                                     {output_layer}, {}, &outputs);
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值