flops统计工具

现有网络中有些层未被纳入开源统计工具,所以实现了一个简易的统计工具,方便拓展。

def calc_flops(model, input, multiply_adds=False, verbose=True):
    """Estimate the FLOPs of one forward pass of ``model`` on ``input``.

    Forward hooks are attached to every leaf module of a supported type
    (Conv2d/ConvTranspose2d, Linear, BatchNorm2d, ReLU/PReLU,
    MaxPool2d/AvgPool2d, LSTM, GRU and, when the name is in scope, FSMNZQ),
    the model is run once, and the per-layer estimates are summed.  The hooks
    are removed again before returning, so the model is left unmodified and
    repeated calls give the same result.

    Args:
        model: a ``torch.nn.Module``.
        input: the single positional argument passed to ``model``.
        multiply_adds: when True a multiply-accumulate counts as two
            operations, otherwise as one.
        verbose: when True, print every leaf module, the trainable parameter
            count and the FLOP totals.

    Returns:
        The total operation count as an integer.
    """
    ops_per_mac = 2 if multiply_adds else 1

    conv_ops, bn_ops, relu_ops, linear_ops = [], [], [], []
    pooling_ops, rnn_ops, fsmn_ops = [], [], []
    handles = []

    # FSMNZQ is a project-specific layer.  Resolve it lazily so the counter
    # still works for models (and environments) that do not define it.
    try:
        fsmn_cls = FSMNZQ
    except NameError:
        fsmn_cls = None

    def conv_hook(self, input, output):
        batch_size = input[0].size(0)
        # output[0] is the first sample of the batch: (channels, height, width).
        output_channels, output_height, output_width = output[0].size()

        # in_channels is always divisible by groups, so floor division is
        # exact and keeps the count an integer.
        kernel_ops = (self.kernel_size[0] * self.kernel_size[1]
                      * (self.in_channels // self.groups) * ops_per_mac)
        bias_ops = 1 if self.bias is not None else 0

        per_position = output_channels * (kernel_ops + bias_ops)
        conv_ops.append(batch_size * per_position * output_height * output_width)

    def linear_hook(self, input, output):
        # The layer is applied once per row, i.e. once for every index over
        # all leading dimensions.  Each row must be counted exactly once.
        rows = input[0].numel() // self.in_features
        weight_ops = self.weight.nelement() * ops_per_mac
        bias_ops = self.bias.nelement() if self.bias is not None else 0
        linear_ops.append(rows * (weight_ops + bias_ops))

    def fsmn_hook(self, input, output):
        features = input[0]
        # NOTE(review): assumes the last dimension is the feature dimension
        # and the filter is applied once per frame - confirm against FSMNZQ.
        frames = features.numel() // features.size(-1)
        weight_ops = self.filter.nelement() * ops_per_mac
        fsmn_ops.append(frames * weight_ops)

    def gru_cell(input_size, hidden_size, bias=True):
        total_ops = 0
        # r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr})
        # z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz})
        state_ops = (hidden_size + input_size) * hidden_size + hidden_size
        if bias:
            state_ops += hidden_size * 2
        total_ops += state_ops * 2

        # n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn}))
        total_ops += (hidden_size + input_size) * hidden_size + hidden_size
        if bias:
            total_ops += hidden_size * 2
        # Hadamard product r * (...)
        total_ops += hidden_size

        # h' = (1 - z) * n + z * h  -> two Hadamard products and one add
        total_ops += hidden_size * 3

        return total_ops

    def lstm_cell(input_size, hidden_size, bias):
        # Four gates share the same cost.
        state_ops = (input_size + hidden_size) * hidden_size + hidden_size
        if bias:
            state_ops += hidden_size * 2
        total_ops = state_ops * 4
        # Cell-state update (hidden_size * 3) plus the output Hadamard product.
        total_ops += hidden_size * 3
        total_ops += hidden_size
        return total_ops

    def recurrent_ops(module, sequence, cell_fn):
        """Cost of a (possibly stacked / bidirectional) recurrent module."""
        if sequence.dim() == 2:
            # Unbatched input: (steps, features).
            batch_size, num_steps = 1, sequence.size(0)
        elif module.batch_first:
            batch_size, num_steps = sequence.size(0), sequence.size(1)
        else:
            batch_size, num_steps = sequence.size(1), sequence.size(0)

        num_directions = 2 if module.bidirectional else 1
        per_step = cell_fn(module.input_size, module.hidden_size, module.bias)
        # Deeper layers consume the concatenated outputs of all directions.
        deeper_input = module.hidden_size * num_directions
        for _ in range(module.num_layers - 1):
            per_step += cell_fn(deeper_input, module.hidden_size, module.bias)
        return per_step * num_directions * batch_size * num_steps

    def gru_hook(self, input, output):
        rnn_ops.append(recurrent_ops(self, input[0], gru_cell))

    def lstm_hook(self, input, output):
        rnn_ops.append(recurrent_ops(self, input[0], lstm_cell))

    def bn_hook(self, input, output):
        bn_ops.append(input[0].nelement())

    def relu_hook(self, input, output):
        relu_ops.append(input[0].nelement())

    def pooling_hook(self, input, output):
        batch_size = input[0].size(0)
        output_channels, output_height, output_width = output[0].size()

        # kernel_size may be a single int or a (height, width) tuple.
        kernel = self.kernel_size
        if isinstance(kernel, int):
            kernel_h, kernel_w = kernel, kernel
        else:
            kernel_h, kernel_w = kernel
        per_position = output_channels * kernel_h * kernel_w
        pooling_ops.append(batch_size * per_position * output_height * output_width)

    hook_table = [
        ((torch.nn.Conv2d, torch.nn.ConvTranspose2d), conv_hook),
        ((torch.nn.Linear,), linear_hook),
        ((torch.nn.BatchNorm2d,), bn_hook),
        ((torch.nn.ReLU, torch.nn.PReLU), relu_hook),
        ((torch.nn.MaxPool2d, torch.nn.AvgPool2d), pooling_hook),
        ((torch.nn.LSTM,), lstm_hook),
        ((torch.nn.GRU,), gru_hook),
    ]
    if fsmn_cls is not None:
        hook_table.append(((fsmn_cls,), fsmn_hook))

    def register_hooks(net):
        children = list(net.children())
        if children:
            for child in children:
                register_hooks(child)
            return
        if verbose:
            print(net)
        for module_types, hook in hook_table:
            if isinstance(net, module_types):
                handles.append(net.register_forward_hook(hook))

    register_hooks(model)
    try:
        # Only shapes are needed, so skip building the autograd graph.
        with torch.no_grad():
            _ = model(input)
    finally:
        # Always detach the hooks, even if the forward pass fails, so they do
        # not accumulate on the model across calls.
        for handle in handles:
            handle.remove()

    total_flops = (sum(conv_ops) + sum(linear_ops) + sum(bn_ops) + sum(relu_ops)
                   + sum(pooling_ops) + sum(rnn_ops) + sum(fsmn_ops))
    fsmn_flops = sum(fsmn_ops) + sum(linear_ops)
    lstm_flops = sum(rnn_ops)

    if verbose:
        params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print('The network has {} params.'.format(params))

        print( total_flops,fsmn_flops,lstm_flops)
        print('  + Number of FLOPs: %.2f M' % (total_flops/1000**2))
    return total_flops

当前工具支持:fsmn gru lstm cnn linear relu prelu batchnorm pooling

主要依赖:net.register_forward_hook hook机制获取输入输出维度

开源工具:thop 调用如下

# Example call of the open-source counter thop (pytorch-OpCounter).
# `model` and `batch` are not defined in this snippet; the caller must supply them.
from thop import profile
from thop import clever_format
# profile() is unpacked into two values here, named flops and params.
# NOTE(review): thop describes the first value as MACs rather than FLOPs - confirm.
flops, params = profile(model, inputs=(batch,), verbose=False)
# clever_format() is given both numbers together with the "%.3f" format string.
flops, params = clever_format([flops, params], "%.3f")

 

ref:GitHub - Lyken17/pytorch-OpCounter: Count the MACs / FLOPs of your PyTorch model.

 

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值