tensorflow 模型浮点数计算量和参数量估计

TensorFlow 模型浮点数计算量和参数量统计
2018-08-28

本博文整理了如何对一个 TensorFlow 模型的浮点数计算量(FLOPs)和参数量进行统计。
stats_graph.py

import tensorflow as tf
def stats_graph(graph):
    flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())
    params = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
    print('FLOPs: {};    Trainable params: {}'.format(flops.total_float_ops, params.total_parameters))

利用高斯分布对变量进行初始化会耗费一定的 FLOP

C[25,9]=A[25,16]B[16,9] FLOPs=(16+15)×(25×9)=6975FLOPs(inTFstyle)=(16+16)×(25×9)=7200total_parameters=25×16+16×9=544

with tf.Graph().as_default() as graph:
    A = tf.get_variable(initializer=tf.random_normal_initializer(dtype=tf.float32), shape=(25, 16), name='A')
    B = tf.get_variable(initializer=tf.random_normal_initializer(dtype=tf.float32), shape=(16, 9), name='B')
    C = tf.matmul(A, B, name='ouput')
    
    stats_graph(graph)

输出为:
FLOPs: 8288; Trainable params: 544

利用常量初始化器对变量进行初始化不会耗费 FLOP

with tf.Graph().as_default() as graph:
    A = tf.get_variable(initializer=tf.constant_initializer(value=1, dtype=tf.float32), shape=(25, 16), name='A')
    B = tf.get_variable(initializer=tf.zeros_initializer(dtype=tf.float32), shape=(16, 9), name='B')
    C = tf.matmul(A, B, name='ouput')
    
    stats_graph(graph)

输出为:
FLOPs: 7200; Trainable params: 544

Frozen graph

通常我们对耗费在初始化上的 FLOPs 并不感兴趣,因为它是发生在训练过程之前且是一次性的,我们感兴趣的是模型部署之后在生产环境下的 FLOPs。我们可以通过 Freeze 计算图的方式得到除去初始化 FLOPs 的、模型部署后推断过程中耗费的 FLOPs。

from tensorflow.python.framework import graph_util
def load_pb(pb):
    with tf.gfile.GFile(pb, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph
with tf.Graph().as_default() as graph:
    # ***** (1) Create Graph *****
    A = tf.Variable(initial_value=tf.random_normal([25, 16]))
    B = tf.Variable(initial_value=tf.random_normal([16, 9]))
    C = tf.matmul(A, B, name='output')
    
    print('stats before freezing')
    stats_graph(graph)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # ***** (2) freeze graph *****
        output_graph = graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), ['output'])
        with tf.gfile.GFile('graph.pb', "wb") as f:
            f.write(output_graph.SerializeToString())
# ***** (3) Load frozen graph *****
graph = load_pb('./graph.pb')
print('stats after freezing')
stats_graph(graph)

输出为:

stats before freezing
FLOPs: 8288; Trainable params: 544
INFO:tensorflow:Froze 2 variables.
INFO:tensorflow:Converted 2 variables to const ops.
stats after freezing
FLOPs: 7200; Trainable params: 0

与 Keras 的结合

from keras import backend as K
from keras.layers import Dense
from keras.models import Sequential
from keras.initializers import Constant
model = Sequential()
model.add(Dense(32, input_dim=4, bias_initializer=Constant(value=0), kernel_initializer=Constant(value=1)))
sess = K.get_session()
graph = sess.graph
stats_graph(graph)

输出为:
FLOPs: 0; Trainable params: 160
Using TensorFlow backend.
2 ops no flops stats due to incomplete shapes.
2 ops no flops stats due to incomplete shapes.
model.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_1 (Dense) (None, 32) 160
=================================================================
Total params: 160
Trainable params: 160
Non-trainable params: 0
_________________________________________________________________

DL

About

This is Robert Lexis (FengCun Li). To see the world, things dangerous to come to, to see behind walls, to draw closer, to find each other and to feel. That is the purpose of LIFE.
Recent Posts

Static variable in inline
Iterator invalidation rul
Emplace back
Perfect forward

转载于:https://www.cnblogs.com/o-v-o/p/11042066.html

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值