Some layers in our networks are not covered by the open-source profiling tools, so we implemented a simple FLOPs counter that is easy to extend.
import numpy as np
import torch

# Note: FSMNZQ (referenced below) is the project's custom FSMN layer and
# must be imported from the project code.

def calc_flops(model, input):
    def conv_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()
        # Ops per output element: kernel area x input channels per group,
        # doubled if a multiply-add is counted as two operations.
        kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels // self.groups) * (2 if multiply_adds else 1)
        bias_ops = 1 if self.bias is not None else 0
        params = output_channels * (kernel_ops + bias_ops)
        flops = batch_size * params * output_height * output_width
        list_conv.append(flops)
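    # Worked example (illustrative numbers, not from the source): a 3x3 Conv2d
    # with in_channels=16, out_channels=32, groups=1, bias, a 28x28 output map,
    # and multiply_adds=False (one MAC counted as one op):
    #   kernel_ops = 3*3*16 = 144, bias_ops = 1
    #   params     = 32 * (144 + 1) = 4,640 ops per output position
    #   flops      = 1 * 4,640 * 28 * 28 = 3,637,760 ops per forward pass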

    def linear_hook(self, input, output):
        # 2-D input is (batch, features); 3-D input is treated as
        # (num_steps, batch, features), with the step count applied below.
        if input[0].dim() == 2:
            batch_size, num_steps = input[0].size(0), 1
        else:
            batch_size, num_steps = 1, input[0].size(0)
        weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
        bias_ops = self.bias.nelement() if self.bias is not None else 0
        flops = batch_size * (weight_ops + bias_ops)
        flops *= num_steps
        list_linear.append(flops)

    def fsmn_hook(self, input, output):
        # Same layout convention as linear_hook: 2-D is (batch, features),
        # 3-D is (num_steps, batch, features).
        if input[0].dim() == 2:
            batch_size, num_steps = input[0].size(0), 1
        else:
            batch_size, num_steps = 1, input[0].size(0)
        weight_ops = self.filter.nelement() * (2 if multiply_adds else 1)
        flops = num_steps * weight_ops
        flops *= batch_size
        list_fsmn.append(flops)

    def gru_cell(input_size, hidden_size, bias=True):
        total_ops = 0
        # r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr})
        # z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz})
        state_ops = (hidden_size + input_size) * hidden_size + hidden_size
        if bias:
            state_ops += hidden_size * 2
        total_ops += state_ops * 2
        # n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn}))
        total_ops += (hidden_size + input_size) * hidden_size + hidden_size
        if bias:
            total_ops += hidden_size * 2
        # r Hadamard: r * (W_{hn} h + b_{hn})
        total_ops += hidden_size
        # h' = (1 - z) * n + z * h: two Hadamard products and one add
        total_ops += hidden_size * 3
        return total_ops
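    # Worked example (illustrative): input_size=64, hidden_size=128, bias=True:
    #   r/z gates: 2 * ((128+64)*128 + 128 + 256) = 49,920 ops
    #   n gate:        (128+64)*128 + 128 + 256   = 24,960 ops
    #   r Hadamard + h' update: 128 + 384         =    512 ops
    #   total per step, per layer, per sequence: 75,392 ops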

    def gru_hook(self, input, output):
        if self.batch_first:
            batch_size = input[0].size(0)
            num_steps = input[0].size(1)
        else:
            batch_size = input[0].size(1)
            num_steps = input[0].size(0)
        bias = self.bias
        input_size = self.input_size
        hidden_size = self.hidden_size
        num_layers = self.num_layers
        total_ops = gru_cell(input_size, hidden_size, bias)
        # Layers after the first take the previous layer's hidden state as input.
        for i in range(num_layers - 1):
            total_ops += gru_cell(hidden_size, hidden_size, bias)
        total_ops *= batch_size
        total_ops *= num_steps
        # GRU ops are accumulated into the same bucket as LSTM ops.
        list_lstm.append(total_ops)

    def lstm_cell(input_size, hidden_size, bias=True):
        total_ops = 0
        # i/f/g/o gates: each is W_i x + W_h h (+ biases) plus the activation.
        state_ops = (input_size + hidden_size) * hidden_size + hidden_size
        if bias:
            state_ops += hidden_size * 2
        total_ops += state_ops * 4
        # c' = f * c + i * g: two Hadamard products and one add
        total_ops += hidden_size * 3
        # h' = o * tanh(c')
        total_ops += hidden_size
        return total_ops
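    # Worked example (illustrative): input_size=64, hidden_size=128, bias=True:
    #   four gates: 4 * ((64+128)*128 + 128 + 256) = 99,840 ops
    #   cell and hidden updates: 384 + 128         =    512 ops
    #   total per step, per layer, per sequence: 100,352 ops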

    def lstm_hook(self, input, output):
        if self.batch_first:
            batch_size = input[0].size(0)
            num_steps = input[0].size(1)
        else:
            batch_size = input[0].size(1)
            num_steps = input[0].size(0)
        bias = self.bias
        input_size = self.input_size
        hidden_size = self.hidden_size
        num_layers = self.num_layers
        total_ops = lstm_cell(input_size, hidden_size, bias)
        for i in range(num_layers - 1):
            total_ops += lstm_cell(hidden_size, hidden_size, bias)
        total_ops *= batch_size
        total_ops *= num_steps
        list_lstm.append(total_ops)

    def bn_hook(self, input, output):
        # One op per element (the scale-and-shift is counted as a single op).
        list_bn.append(input[0].nelement())

    def relu_hook(self, input, output):
        list_relu.append(input[0].nelement())

    def pooling_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()
        # kernel_size may be an int or an (h, w) tuple.
        if isinstance(self.kernel_size, tuple):
            kernel_ops = self.kernel_size[0] * self.kernel_size[1]
        else:
            kernel_ops = self.kernel_size * self.kernel_size
        params = output_channels * kernel_ops
        flops = batch_size * params * output_height * output_width
        list_pooling.append(flops)

    def foo(net):
        # Recurse to leaf modules and register the matching hook on each.
        childrens = list(net.children())
        if not childrens:
            print(net)
            if isinstance(net, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                net.register_forward_hook(conv_hook)
            if isinstance(net, torch.nn.Linear):
                net.register_forward_hook(linear_hook)
            if isinstance(net, torch.nn.BatchNorm2d):
                net.register_forward_hook(bn_hook)
            if isinstance(net, (torch.nn.ReLU, torch.nn.PReLU)):
                net.register_forward_hook(relu_hook)
            if isinstance(net, (torch.nn.MaxPool2d, torch.nn.AvgPool2d)):
                net.register_forward_hook(pooling_hook)
            if isinstance(net, torch.nn.LSTM):
                net.register_forward_hook(lstm_hook)
            if isinstance(net, torch.nn.GRU):
                net.register_forward_hook(gru_hook)
            if isinstance(net, FSMNZQ):  # FSMNZQ: the project's FSMN layer
                net.register_forward_hook(fsmn_hook)
            return
        for c in childrens:
            foo(c)

    # Count one multiply-accumulate as a single op; set True to count it as two.
    multiply_adds = False
    list_conv, list_bn, list_relu, list_linear, list_pooling, list_lstm, list_fsmn = [], [], [], [], [], [], []
    foo(model)
    # One forward pass triggers the hooks and fills the per-layer op lists.
    _ = model(input)
    total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu)
                   + sum(list_pooling) + sum(list_lstm) + sum(list_fsmn))
    fsmn_flops = sum(list_fsmn) + sum(list_linear)
    lstm_flops = sum(list_lstm)
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print('The network has {} params.'.format(params))
    print(total_flops, fsmn_flops, lstm_flops)
    print(' + Number of FLOPs: %.2f M' % (total_flops / 1000 ** 2))
    return total_flops
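A minimal usage sketch (the toy model below is hypothetical, and the FSMNZQ class must be in scope; drop that branch from foo otherwise). calc_flops registers the hooks, runs one forward pass, and returns the total op count:

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
)
dummy = torch.randn(1, 3, 32, 32)
total = calc_flops(model, dummy)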
The tool currently supports: FSMN, GRU, LSTM, Conv2d/ConvTranspose2d, Linear, ReLU/PReLU, BatchNorm2d, and Max/AvgPool2d.
Main dependency: net.register_forward_hook, the hook mechanism used to capture each layer's input and output dimensions.
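A minimal sketch of the hook mechanism itself (layer and shapes are illustrative): a forward hook receives the module, a tuple of its inputs, and its output, which is how each counter above reads layer dimensions:

import torch

layer = torch.nn.Linear(8, 4)

def shape_hook(module, inputs, output):
    # inputs is a tuple of the positional arguments passed to forward().
    print(module.__class__.__name__, tuple(inputs[0].shape), tuple(output.shape))

handle = layer.register_forward_hook(shape_hook)
layer(torch.randn(2, 8))   # prints: Linear (2, 8) (2, 4)
handle.remove()            # detach the hook when done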
Open-source alternative: thop, invoked as follows:
from thop import profile
from thop import clever_format
flops, params = profile(model, inputs=(batch,), verbose=False)
flops, params = clever_format([flops, params], "%.3f")
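A complete invocation might look like the following (the model and input shape are illustrative; thop must be installed, e.g. via pip install thop):

import torch
from thop import profile, clever_format

model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.ReLU())
batch = torch.randn(1, 3, 224, 224)
flops, params = profile(model, inputs=(batch,), verbose=False)
flops, params = clever_format([flops, params], "%.3f")
print(flops, params)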
Ref: Lyken17/pytorch-OpCounter (count the MACs / FLOPs of your PyTorch model), https://github.com/Lyken17/pytorch-OpCounter