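# Usage:
# 1) Edit the Caffe python path passed to sys.path.insert() below so that it
#    points at your Caffe build's python directory.
# 2) Run: python cal_params.py /data1/......./deploy.prototxt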
# How to use:
#   python python_file.py deploy.prototxt
import sys
#sys.path.insert(0, "/data_1/caffe-path/caffe-refinedet-shufflenet-nocudnn/python")
sys.path.insert(0, "/data_1/caffe-path/RefineDet-master/python")
import caffe
caffe.set_mode_cpu()
import numpy as np
from numpy import prod, sum
from pprint import pprint
def print_net_parameters_flops(deploy_file):
    """Print per-layer weight shapes, FLOP estimates and parameter counts.

    The network definition ``deploy_file`` is loaded with
    ``caffe.Net(deploy_file, caffe.TEST)``; no .caffemodel is passed, so only
    the shapes derived from the prototxt matter here.  For every layer of type
    Convolution, DepthwiseConvolution, ConvolutionDepthwise or InnerProduct,
    one row is printed with:

    * layer name,
    * shape of the layer's first parameter blob (its weights),
    * shape of the layer's first top (output) blob,
    * layer type,
    * estimated operation count ("Flops"),
    * number of elements in the weight blob ("params").

    The operation count is the number of weight elements, multiplied by the
    output's spatial size (last two dims of the top blob) for convolution
    types.  Bias terms and the batch dimension are not included.

    Parameter counts (per layer and in the printed total) only count each
    layer's first parameter blob; bias blobs are not counted.

    Args:
        deploy_file: path to a deploy ``.prototxt`` file.
    """
    # Layer types whose cost scales with the output's spatial size (H * W).
    # 'ConvolutionDepthwise' belongs here too; previously it was costed like
    # an InnerProduct in the per-layer row.
    conv_types = ('Convolution', 'DepthwiseConvolution', 'ConvolutionDepthwise')
    # Fully-connected layers: one operation per weight element.
    fc_types = ('InnerProduct',)

    print("Net: " + deploy_file)
    net = caffe.Net(deploy_file, caffe.TEST)
    total_flops = 0
    print("Layer-wise parameters: ")
    print('layer name'.ljust(20), 'Filter Shape'.ljust(20),
          'Output Size'.ljust(20), 'Layer Type'.ljust(20),
          'Flops'.ljust(20), 'params'.ljust(20))

    top_names = net.top_names
    for layer_name, layer in net.layer_dict.items():
        layer_type = layer.type
        if layer_type not in conv_types and layer_type not in fc_types:
            continue

        weight_shape = net.params[layer_name][0].data.shape
        # Resolve the output blob through the layer's top list instead of
        # assuming the top blob is named after the layer.
        out_shape = net.blobs[top_names[layer_name][0]].data.shape

        params_num = int(np.prod(weight_shape))
        if layer_type in conv_types:
            # Every weight is applied once per output spatial position.
            cur_flops = params_num * out_shape[-1] * out_shape[-2]
        else:
            cur_flops = params_num
        # Accumulate the same value that is printed, so rows and total agree.
        total_flops += cur_flops

        print(layer_name.ljust(20),               # layer name
              str(weight_shape).ljust(20),        # filter shape
              str(out_shape).ljust(20),           # output size
              layer_type.ljust(20),               # layer type
              str(cur_flops).ljust(20),           # flops
              str(params_num).ljust(20))          # params

    # Number of layers that own at least one parameter blob.
    print('layers num: ' + str(len(net.params)))
    # A list (not a generator) is summed because this module imports numpy's
    # `sum`, which does not accept generators reliably.
    total_params = sum([int(np.prod(blobs[0].data.shape))
                        for blobs in net.params.values()])
    print("Total number of parameters: " + str(total_params))
    print("Total number of flops: " + str(total_flops))
if __name__ == '__main__':
    # Exactly one argument is expected: the deploy prototxt to analyse.
    if len(sys.argv) != 2:
        print('Usage:')
        # Report the actual script name rather than a hard-coded one.
        print('python ' + sys.argv[0] + ' deploy.prototxt')
        # Non-zero status signals the usage error; sys.exit does not depend
        # on the `site` module the way the bare `exit()` builtin does.
        sys.exit(1)
    deploy_file = sys.argv[1]
    print_net_parameters_flops(deploy_file)