tensorflow实现计算模型大小和FLOPS!以及与.ckpt和weight文件大小转换说明!Model‘ object has no attribute ‘get_operations‘!

一、计算参数量和FOPS代码

本人觉得下面这个代码出错率少,因为沾到.pb文件,有些时候你总是要出点错,比如说:

AssertionError: output is not in graph

而且你去把节点打印出来后,把输出节点换上,还是错,不可思议。
算了,你直接把你的模型输入,下面的框框!

    # 模型开始处××××××××××××××××××××××××××××
    # ***** (1) Create Graph *****
    input_data = tf.Variable(initial_value=tf.random_normal([1, 416,416,3]))
    route_1, route_2, input_data = backbone.darknet53(input_data, True)
   # 模型结束××××××××××××××××××××××××××××

只是可能比较麻烦


from tensorflow.python.framework import graph_util
from tensorflow.contrib.layers import flatten
import numpy as np
import tensorflow as tf
# 自己函数需要用到的函数

import core.utils as utils
import core.common as common
import core.backbone as backbone
from core.config import cfg


def stats_graph(graph):
    flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())
    params = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
    print('GFLOPs: {};    Trainable params: {}'.format(flops.total_float_ops / 1000000000.0, params.total_parameters))


def load_graph(frozen_graph_filename):
    # We load the protobuf file from the disk and parse it to retrieve the
    # unserialized graph_def
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Then, we import the graph_def into a new Graph and return it
    with tf.Graph().as_default() as graph:
        # The name var will prefix every op/nodes in your graph
        # Since we load everything in a new graph, this is not needed
        tf.import_graph_def(graph_def, name="prefix")
    return graph


with tf.Graph().as_default() as graph:
    # 模型开始处××××××××××××××××××××××××××××
    # ***** (1) Create Graph *****
    input_data = tf.Variable(initial_value=tf.random_normal([1, 416,416,3]))
    route_1, route_2, input_data = backbone.darknet53(input_data, True)

    # 模型结束××××××××××××××××××××××××××××

    print('stats before freezing')
    stats_graph(graph)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # ***** (2) freeze graph *****
        output_graph = graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), ['prefix/pred_lbbox/concat_2'])
        with tf.gfile.GFile('graph.pb', "wb") as f:
            f.write(output_graph.SerializeToString())

def count_flops(graph):
    flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())
    print('FLOPs: {}'.format(flops.total_float_ops))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    graph = load_graph('./yolov3_coco.pb')
    stats_graph(graph)

把对应的路径改了,就可以运行。
结果如下:

    stem_block/stem_block_conv1_l0 (--/528 params)
   stem_block/stem_block_conv1_l0/BatchNorm (--/16 params)
     stem_block/stem_block_conv1_l0/BatchNorm/beta (16, 16/16 params)
   stem_block/stem_block_conv1_l0/weights (1x1x32x16, 512/512 params)
 stem_block/stem_block_conv1_l1 (--/4.64k params)
   stem_block/stem_block_conv1_l1/BatchNorm (--/32 params)
     stem_block/stem_block_conv1_l1/BatchNorm/beta (32, 32/32 params)
   stem_block/stem_block_conv1_l1/weights (3x3x16x32, 4.61k/4.61k params)
 stem_block/stem_block_output (--/2.08k params)
   stem_block/stem_block_output/BatchNorm (--/32 params)
     stem_block/stem_block_output/BatchNorm/beta (32, 32/32 params)
   stem_block/stem_block_output/weights (1x1x64x32, 2.05k/2.05k params)

======================End of Report==========================
GFLOPs: 7.98617009;    Trainable params: 7958576

如果你要用点.ckpt去计算模型的参数量,可以使用如下代码:

from tensorflow.python import pywrap_tensorflow
import tensorflow as tf
import os
import numpy as np
model_dir = "./checkpoint/"
checkpoint_path = os.path.join(model_dir, "yolov3_coco_demo.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
total_parameters = 0
for key in var_to_shape_map:#list the keys of the model
    # print(key)
    # print(reader.get_tensor(key))
    shape = np.shape(reader.get_tensor(key))  #get the shape of the tensor in the model
    shape = list(shape)
    # print(shape)
    # print(len(shape))
    variable_parameters = 1
    for dim in shape:
        # print(dim)
        variable_parameters *= dim
    # print(variable_parameters)
    total_parameters += variable_parameters

print(total_parameters)

结果如下:

62001757

Process finished with exit code 0

二、模型参数的换算

上面所求yolov3模型的参数是:62001757
这里是参数的个数,请各位切记!!!!
然后你从官方下来的文件里面说明:yolov3.weight
在这里插入图片描述
还有点.ckpt文件里面保存模型参数大小的文件:
在这里插入图片描述
还有pb文件:
在这里插入图片描述
都是242kb左右!怎么会这样了?
在yolov3里面保存参数的类型是.float32格式,这里面是4个字节,每个字节8位。那么1MB(1024个参数)的参数,就是占用内存为4MB(1B=8bit)。
那么62001757占用内存大小为:62001757*4 = 248007028(单位是b)
最后换成:KB单位,
模型大小为:248007028242194.363281。
模型大小就是这样换算来的。

三、报错 AttributeError: ‘Model’ object has no attribute ‘get_operations’

如果在计算FLOPs的时候,模型报错的话,那么你可能直接把.pb文件输入了:

def count_flops(graph):
    flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())
    print('FLOPs: {}'.format(flops.total_float_ops))
stats_graph('./yolov3_coco.pb')

这里代码需要修改,这样修改!!!

def count_flops(graph):
    flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())
    print('FLOPs: {}'.format(flops.total_float_ops))
    
stats_graph('./yolov3_coco.pb')

def load_graph(frozen_graph_filename):
    # We load the protobuf file from the disk and parse it to retrieve the
    # unserialized graph_def
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Then, we import the graph_def into a new Graph and return it
    with tf.Graph().as_default() as graph:
        # The name var will prefix every op/nodes in your graph
        # Since we load everything in a new graph, this is not needed
        tf.import_graph_def(graph_def, name="prefix")
    return graph

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    graph = load_graph('./yolov3_coco.pb')
    stats_graph(graph)

代码就可以正常运行了!!!

  • 4
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值