Count the FLOPs and trainable parameters of a TensorFlow model
Step 1:
Convert your TensorFlow checkpoint model to the PB (.pb) format.
Step 2:
Call the functions as follows:
import tensorflow as tf
from tensorflow.python.framework import graph_util
def import_model(pb_model):
    """Load a serialized GraphDef from a PB file into a new ``tf.Graph``.

    Args:
        pb_model: Path to the PB file containing the serialized GraphDef.

    Returns:
        A new ``tf.Graph`` into which the parsed GraphDef has been imported.
    """
    # Read the raw bytes and parse them into a GraphDef message.
    with tf.gfile.GFile(pb_model, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import into a fresh graph so the default graph is left untouched;
    # name='' imports the nodes without adding a name prefix.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
    return graph
def count_graph_FLOPs(graph):
    """Print and return the FLOPs and trainable parameter count of ``graph``.

    Runs the TensorFlow profiler twice on ``graph``: once with the
    float-operation options and once with the trainable-variables-parameter
    options, then prints both totals on one line.

    Args:
        graph: The ``tf.Graph`` to profile.

    Returns:
        A ``(total_float_ops, total_parameters)`` tuple with the same values
        that are printed.
    """
    option_builder = tf.profiler.ProfileOptionBuilder
    flops = tf.profiler.profile(
        graph, options=option_builder.float_operation())
    # NOTE(review): if the PB was frozen (variables converted to constants),
    # the trainable parameter count may come out as 0 -- confirm.
    params = tf.profiler.profile(
        graph, options=option_builder.trainable_variables_parameter())
    print('FLOPs: {}; Trainable params: {}'.format(flops.total_float_ops, params.total_parameters))
    return flops.total_float_ops, params.total_parameters
if __name__ == '__main__':
    # Replace 'userModel.pb' with the path to your own PB model.
    graph = import_model('userModel.pb')
    # The profiling function defined in this file is count_graph_FLOPs;
    # the previously called stats_graph is not defined here.
    count_graph_FLOPs(graph)