tensorflow2.0+ 计算模型Flops方法

本文介绍了一个用于计算TensorFlow2.0中Keras模型(如EfficientNetB0和FasterR-CNN)FLOPs(浮点运算次数)的函数。通过示例展示了如何使用该函数评估模型的计算效率,包括RPN和Classifier部分的FLOPs总和。
摘要由CSDN通过智能技术生成

参考:TF 2.0 Feature: Flops calculation · Issue #32809 · tensorflow/tensorflow (github.com)

import tensorflow as tf
import numpy as np

def get_flops(model, model_inputs) -> float:
        """
        Calculate FLOPS [GFLOPs] for a tf.keras.Model or tf.keras.Sequential model
        in inference mode. It uses tf.compat.v1.profiler under the hood.
        """
        # if not hasattr(model, "model"):
        #     raise wandb.Error("self.model must be set before using this method.")

        if not isinstance(
            model, (tf.keras.models.Sequential, tf.keras.models.Model)
        ):
            raise ValueError(
                "Calculating FLOPS is only supported for "
                "`tf.keras.Model` and `tf.keras.Sequential` instances."
            )

        from tensorflow.python.framework.convert_to_constants import (
            convert_variables_to_constants_v2_as_graph,
        )

        # Compute FLOPs for one sample
        batch_size = 1
        inputs = [
            tf.TensorSpec([batch_size] + inp.shape[1:], inp.dtype)
            for inp in model_inputs
        ]

        # convert tf.keras model into frozen graph to count FLOPs about operations used at inference
        real_model = tf.function(model).get_concrete_function(inputs)
        frozen_func, _ = convert_variables_to_constants_v2_as_graph(real_model)

        # Calculate FLOPs with tf.profiler
        run_meta = tf.compat.v1.RunMetadata()
        opts = (
            tf.compat.v1.profiler.ProfileOptionBuilder(
                tf.compat.v1.profiler.ProfileOptionBuilder().float_operation()
            )
            .with_empty_output()
            .build()
        )

        flops = tf.compat.v1.profiler.profile(
            graph=frozen_func.graph, run_meta=run_meta, cmd="scope", options=opts
        )

        tf.compat.v1.reset_default_graph()

        # convert to GFLOPs
        return (flops.total_float_ops / 1e9)/2
    
    
    
#Usage

if __name__ =="__main__":
    image_model = tf.keras.applications.EfficientNetB0(include_top=False, weights=None)
    
    x = tf.constant(np.random.randn(1,256,256,3))
    
    print(get_flops(image_model, [x]))

以计算Faster R-CNN 为例:

在加载模型时,随机生成指定张量进行简单计算。

        # 计算整个模型的总参数数量
        total_params = rpn_params + classifier_params
        print("Total Parameters: ", total_params)
        
        x = tf.constant(np.random.randn(1,640,640,3))
        
        x1 = tf.constant(np.random.randn(1,640,640,1024))
        y1 = tf.constant(np.random.randn(640,640,4))
    
        # 计算flops
        model_rpn_flops = self.get_flops(self.model_rpn, [x])
        print("model_rpn", model_rpn_flops)
        
        model_classifier_flops = self.get_flops(self.model_classifier, [x1, y1])
        print("model_classifier", model_classifier_flops)
        
        print("Total Flops:", model_rpn_flops + model_classifier_flops)

计算结果:

  • 10
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值