参考:TF 2.0 Feature: Flops calculation · Issue #32809 · tensorflow/tensorflow (github.com)
import tensorflow as tf
import numpy as np
def get_flops(model, model_inputs) -> float:
    """
    Estimate the inference cost of a Keras model for a single sample.

    The model is traced with ``tf.function``, frozen into a graph whose
    variables are constants, and profiled with ``tf.compat.v1.profiler``
    using its float-operation options.

    Args:
        model: A ``tf.keras.Model`` or ``tf.keras.Sequential`` instance.
        model_inputs: Iterable of tensors, one per model input. Only
            ``shape[1:]`` and ``dtype`` of each tensor are read; the leading
            axis is replaced by a batch size of 1.

    Returns:
        The profiler's ``total_float_ops`` divided by 1e9 and then by 2.
        NOTE(review): the extra halving presumably reports multiply-add
        pairs as one operation (i.e. a MAC-style count in units of 1e9)
        rather than raw GFLOPs -- confirm this is the convention you want.

    Raises:
        ValueError: If ``model`` is neither a ``tf.keras.Model`` nor a
            ``tf.keras.Sequential`` instance.

    Side effects:
        Calls ``tf.compat.v1.reset_default_graph()`` before returning.
    """
    if not isinstance(
        model, (tf.keras.models.Sequential, tf.keras.models.Model)
    ):
        raise ValueError(
            "Calculating FLOPS is only supported for "
            "`tf.keras.Model` and `tf.keras.Sequential` instances."
        )

    # Kept function-local: this helper lives under the non-public
    # ``tensorflow.python`` namespace, so it is only imported when needed.
    from tensorflow.python.framework.convert_to_constants import (
        convert_variables_to_constants_v2_as_graph,
    )

    # Compute FLOPs for one sample: keep every axis but the first and pin
    # the first axis to 1.
    batch_size = 1
    inputs = [
        tf.TensorSpec([batch_size] + inp.shape[1:], inp.dtype)
        for inp in model_inputs
    ]

    # Convert the Keras model into a frozen graph so that only the
    # operations used at inference are counted.
    real_model = tf.function(model).get_concrete_function(inputs)
    frozen_func, _ = convert_variables_to_constants_v2_as_graph(real_model)

    # Calculate FLOPs with tf.profiler; with_empty_output() suppresses the
    # profiler's own report so only the returned object is used.
    run_meta = tf.compat.v1.RunMetadata()
    opts = (
        tf.compat.v1.profiler.ProfileOptionBuilder(
            tf.compat.v1.profiler.ProfileOptionBuilder().float_operation()
        )
        .with_empty_output()
        .build()
    )
    flops = tf.compat.v1.profiler.profile(
        graph=frozen_func.graph, run_meta=run_meta, cmd="scope", options=opts
    )

    tf.compat.v1.reset_default_graph()

    # Scale to units of 1e9, then halve (see the Returns note above).
    return (flops.total_float_ops / 1e9) / 2
# Usage: print the get_flops result for an EfficientNetB0 backbone built
# without pretrained weights and without its classification head.
if __name__ == "__main__":
    image_model = tf.keras.applications.EfficientNetB0(include_top=False, weights=None)
    # Random values are enough here: get_flops reads only the tensor's
    # shape and dtype, never its contents.
    x = tf.constant(np.random.randn(1, 256, 256, 3))
    print(get_flops(image_model, [x]))
以 Faster R-CNN 为例:
加载模型后,按各子模型(RPN 与分类器)的输入随机生成张量,分别传入 get_flops 计算 FLOPs,再将结果相加。
# Total number of parameters of the whole model: the sum of the RPN and
# classifier parameter counts (both computed outside this excerpt).
total_params = rpn_params + classifier_params
print("Total Parameters: ", total_params)
# Random tensors used as stand-in inputs for the two sub-models below.
x = tf.constant(np.random.randn(1,640,640,3))
x1 = tf.constant(np.random.randn(1,640,640,1024))
# NOTE(review): y1 is rank 3 while x and x1 are rank 4. If self.get_flops
# behaves like the module-level get_flops (axis 0 replaced by batch size 1),
# the spec built from y1 would be [1, 640, 4] -- confirm this is the intended
# shape for the classifier's second input.
y1 = tf.constant(np.random.randn(640,640,4))
# Compute FLOPs for each sub-model separately, then report their sum.
model_rpn_flops = self.get_flops(self.model_rpn, [x])
print("model_rpn", model_rpn_flops)
model_classifier_flops = self.get_flops(self.model_classifier, [x1, y1])
print("model_classifier", model_classifier_flops)
print("Total Flops:", model_rpn_flops + model_classifier_flops)
计算结果: