记录tensorflow、caffe模型文件的获取方式

在学习tensorflow模型时,经常需要用到xx.pb网络模型文件。由于自己是刚学tf,因此经常遇到找不到pb文件的情况,下面就将自己找到的一些方法共享给大家。

 

1、华为云modelzoo共享的pb文件

https://www.huaweicloud.com/ascend/resources/modelzoo?ticket=ST-2203190-w5VbNKUD4ZleGiBvoQk3ygqA-sso&locale=zh-cn

在华为云共享的这些modelzoo资源中,选择ATC_XX开始的这些资源,就能直接下载到pb文件。

缺点:提供的pb文件比较少,只有常见的resnet50,VGG16,VGG19等20个左右网络

 

2、github共享的tf models

https://github.com/tensorflow/models/tree/master/research/slim

通过下载xx.tar.gz可以得到对应模型的ckpt文件。因此需要自己写代码才能将ckpt文件转换成Pb文件。

ckpt转换pb代码的python脚本,我这边是参考了大神博客:《.ckpt、.pb、.pbtxt模型相互转换》,进行了简单修改。

示例代码如下:

# -*- coding: utf-8 -*-
import os
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.platform import gfile
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import graph_util


# print ckpt_node_name
def ckpt_node_name(filename):
    checkpoint_path = os.path.join(filename)
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        print('tensor_name: ', key)


# convert .ckpt to .pb to freeze a trained model
def convert_ckpt_to_pb(filename1, filename2):
    # filename1 is a .meta file
    saver = tf.train.import_meta_graph(filename1, clear_devices=True)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    with tf.Session() as sess:
        saver.restore(sess, filename1)
        # you need to change the output node name ['embeddings'] to your model's real name.
        output_graph_def = graph_util.convert_variables_to_constants(sess, input_graph_def, ['output_node_name'])
        with tf.gfile.GFile(filename2, "wb") as f:
            f.write(output_graph_def.SerializeToString())


# print pb_node_name
def pb_node_name(filename):
    def create_graph():
        with tf.gfile.FastGFile(filename, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='')

    create_graph()
    tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
    for tensor_name in tensor_name_list:
        print(tensor_name, '\n')


def convert_pb_to_pbtxt(filename):
    with gfile.FastGFile(filename, 'rb') as f:
        graph_def = tf.GraphDef()

        graph_def.ParseFromString(f.read())

        tf.import_graph_def(graph_def, name='')

        # tf.train.write_graph(graph_def, './', 'protobuf.pbtxt', as_text=True)
        tf.train.write_graph(graph_def, './tmp', 'LSTM111.pbtxt', as_text=True)
    return


def convert_pbtxt_to_pb(filename):
    """Returns a `tf.GraphDef` proto representing the data in the given pbtxt file.
    Args:
      filename: The name of a file containing a GraphDef pbtxt (text-formatted
        `tf.GraphDef` protocol buffer data).
    """
    with tf.gfile.FastGFile(filename, 'r') as f:
        graph_def = tf.GraphDef()

        file_content = f.read()

        # Merges the human-readable string in `file_content` into `graph_def`.
        text_format.Merge(file_content, graph_def)
        tf.train.write_graph(graph_def, './tmp/train', 'lstm.pb', as_text=False)
    return


if __name__ == '__main__':
    model_path = 'D:\\modelzoo\\inception_v1_2016_08_28\\'
    ckpt_path = model_path + 'inception_v1.ckpt'
    # 输出pb模型的路径
    out_pb_path = model_path + 'inception_v1.pb'
    convert_ckpt_to_pb(ckpt_path, out_pb_path)
    print('Convert .ckpt to .pb has finished')

问题1:

但是在使用这段代码时,遇到如下报错:AttributeError: module 'tensorflow._api.v2.train' has no attribute 'import_meta_graph'

解决方法是将头文件引入修改成如下:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

问题2:

解决完以上问题后,又报错:ttributeError: 'NoneType' object has no attribute 'restore'

关于saver.restore问题,一直找不到解决方法。因此我只有卸载tf v2.0,重装成tf v1.15。

由于直接通过命令:pip install tensorflow==1.15 安装比较慢,因此我是直接安装离线包。对应包路径:https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple/tensorflow/

当然也可以加上清华镜像来安装TF:如:

pip install tensorflow==1.15 -i https://pypi.tuna.tsinghua.edu.cn/simple/

安装成功后可以通过pip show tensorflow来确认版本。

但是发现,即使安装成v1.15版本后,仍然报错。截止目前,我仍未找到解决方法。另外要注意的是,ckpt文件只是checkpoint文件,还需要要对应.meta,但是没找到。可能是我没注意吧。

 

3、caffe网络的模型下载方式

https://github.com/BVLC/caffe/wiki/Model-Zoo

当然还有这个网络也能进行caffemodel下载(但是没有prototxt文件):http://dl.caffe.berkeleyvision.org/

 

3、github 共享的tensorflow网络模型下载

https://github.com/IntelAI/models/tree/master/benchmarks

点击后面的FP32可以看到对应模型下载的方式,推荐使用wget xxx

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
TensorFlowCaffe是两个被广泛应用于深度学习的开源框架。它们都有各自的优缺点如下: TensorFlow的优点: 1. 高度灵活:TensorFlow具有高度可定制的图形计算环境,可以构建各种类型的神经网络模型和算法。 2. 广泛的社区支持:由于其受欢迎程度,有庞大的开发者社区支持,提供大量的文档、教程和示例代码。 3. 高性能计算:TensorFlow通过使用计算图和高效的并行计算技术,能够充分利用多核CPU和GPU加速深度学习模型的训练和推理。 4. 支持多种语言:TensorFlow支持多种主流编程语言,如Python、C++和Java等,提供了多种编程接口,方便开发者使用。 5. 模型可移植性:TensorFlow使用统一的模型表示,能够在不同平台和设备上进行无缝迁移和部署。 TensorFlow的缺点: 1. 学习曲线陡峭:相对于其他框架,TensorFlow的学习曲线可能会较陡峭,对于新手可能需要较长时间来熟悉其概念和使用方法。 2. 繁琐的模型构建:在TensorFlow中,需要手动构建计算图,在一些场景下可能需要编写更多的代码。 3. 运行效率不高:由于其设计的灵活性,TensorFlow在一些小规模的深度学习任务上可能会出现较低的运行效率。 Caffe的优点: 1. 简单易用:Caffe使用简单的配置文件来定义模型和训练过程,对于新手入门较为友好。 2. 高效的内存管理:Caffe通过使用内存映射技术,有效地管理内存使用,适用于处理大规模的数据集。 3. 高速的推理速度:由于其专注于推断(inference)过程,Caffe模型的运行速度方面表现出色。 4. 跨平台支持:Caffe支持多种操作系统,可以在Linux、Windows和Mac等平台上运行。 Caffe的缺点: 1. 灵活性较差:相对于TensorFlowCaffe的灵活性较差,对于一些特殊的网络结构和算法可能需要自己进行扩展和定制。 2. 依赖较多:Caffe对于依赖库的需求较多,需要手动安装和配置依赖项。 3. 社区支持相对较少:相比TensorFlowCaffe的社区支持相对较少,文档和教程相对较少。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值