#-*- coding: utf-8 -*-
import os
import sys
import shutil
import struct
from google.protobuf import text_format
import caffe
from caffe.proto import caffe_pb2
LAYER_PARAM = {'Convolution', 'InnerProduct'}  # layer types that are counted


def _blob_count(blob):
    """Return the element count of a Caffe blob with at most four axes.

    Uses the legacy ``num``/``channels``/``height``/``width`` accessors, the
    same ones used for the output-size part of the FLOP count.
    """
    return blob.num * blob.channels * blob.height * blob.width


class CalFlop(object):
    """Report multiply-accumulate and weight counts of a Caffe network.

    Only layers whose ``type`` is in ``LAYER_PARAM`` are counted; every other
    layer contributes nothing. Bias blobs are not included in either count.
    """

    def __init__(self, model, deploy):
        """Load the network and parse its prototxt.

        Args:
            model: path of the trained ``.caffemodel`` weights.
            deploy: path of the deploy ``.prototxt``.
        """
        self.model = model
        self.deploy = deploy
        self.net = caffe.Net(deploy, model, caffe.TEST)
        # Not used by the counting methods in this class.
        self.transformer = caffe.io.Transformer(
            {'data': self.net.blobs['data'].data.shape})
        self.transformer.set_transpose('data', (2, 0, 1))
        # The parsed prototxt supplies each layer's name, type and top blobs.
        self.netlist = caffe_pb2.NetParameter()
        with open(deploy) as f:  # close the handle instead of leaking it
            text_format.Merge(f.read(), self.netlist)

    def GetLayerList(self, verbose=False):
        """Return the names of all layers declared in the deploy prototxt.

        Args:
            verbose: if true, also print each layer's full definition (the
                original always did this, flooding the output).
        """
        names = []
        for layer in self.netlist.layer:
            names.append(layer.name)
            if verbose:
                print(layer)
        return names

    def _counted_layers(self):
        """Yield ``(layer, weights)`` for each layer whose type is in LAYER_PARAM.

        ``weights`` is the layer's first learned blob,
        ``self.net.params[layer.name][0]``.
        """
        params = self.net.params  # fetched once instead of per layer
        for layer in self.netlist.layer:
            if layer.type in LAYER_PARAM:
                yield layer, params[layer.name][0]

    def CalFlops(self, verbose=False):
        """Print and return the network's multiply-accumulate count.

        Per counted layer this is ``weight elements * output height *
        output width``: one multiply-add per weight per output position, for
        a single input (the batch axis of the output is ignored). Many papers
        report this figure as "FLOPs". The output size is read from the
        layer's first top blob, whose name need not equal the layer's name.

        Args:
            verbose: also print the count of each individual layer.

        Returns:
            The total count (the original only printed it).
        """
        blobs = self.net.blobs
        total = 0
        for layer, weights in self._counted_layers():
            top = blobs[layer.top[0]]
            flops = _blob_count(weights) * top.height * top.width
            if verbose:
                print("{} FLOPS is {}".format(layer.name, flops))
            total += flops
        print("Net FLOPS is {}".format(total))
        return total

    def CalParams(self):
        """Print and return the number of weights in the counted layers.

        The count comes from the learned weight blobs rather than from the
        prototxt. This handles non-square kernels (``kernel_h``/``kernel_w``
        or two ``kernel_size`` values), grouped convolutions, and
        InnerProduct layers with a non-default ``axis``. Biases are excluded.

        Returns:
            The total weight count (the original only printed it).
        """
        total = sum(_blob_count(weights) for _, weights in self._counted_layers())
        print("params is {}".format(total))
        return total
if __name__ == '__main__':
    # Usage: python <this script> [deploy.prototxt] [weights.caffemodel]
    # Missing arguments fall back to the original author's VGG-SSD paths.
    MODEL_FILE = sys.argv[1] if len(sys.argv) > 1 else r'/home/ssd/deploy.prototxt'
    PRETRAINED = (sys.argv[2] if len(sys.argv) > 2
                  else r'/home/ssd/VGG_SSD_300x300.caffemodel')
    trans = CalFlop(PRETRAINED, MODEL_FILE)  # CalFlop(model, deploy)
    trans.CalFlops()
    trans.CalParams()
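# 写了个统计模型运算量和参数量的脚本，主要用于让模型优化后的效果更直观些。VGG 可用，VGG-SSD 可用，其他网络没有测试，可能有 bug，希望指出。
# (A script that counts a model's computation and parameter count, mainly to make the effect of model optimization easier to see. Tested with VGG and VGG-SSD only; other networks are untested and there may be bugs, so please report any you find.)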
# 转载请注明出处。(If you repost this, please credit the source.)