通过 tf.trainable_variables() 统计整个网络的参数量
本文摘录了七种方法，它们大同小异，得出的结果也都相同。以下代码使用 TensorFlow 1.x 的图模式 API（在 TensorFlow 2.x 中对应 tf.compat.v1.trainable_variables()）。
def count1():
    """Print and return the total number of trainable parameters.

    For every variable in ``tf.trainable_variables()`` the sizes of its
    dimensions are multiplied together to get its element count, and the
    counts are summed.

    ``get_shape().as_list()`` is used instead of iterating the shape and
    reading ``dim.value``: ``as_list()`` yields plain ints whether the
    shape's dimensions are TF1 ``Dimension`` objects or plain ints
    (TF2 / v2 TensorShape behavior), whereas ``dim.value`` only exists on
    ``Dimension``.

    Returns:
        int: the total parameter count (also printed).

    Raises:
        TypeError: if a variable has an unknown (``None``) dimension.
    """
    total_parameters = 0
    for variable in tf.trainable_variables():
        # Plain ints (or None for an unknown dimension) in both TF1 and TF2.
        shape = variable.get_shape().as_list()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim
        total_parameters += variable_parameters
    print(total_parameters)
    return total_parameters
def count2():
    """Print and return the total trainable parameter count (NumPy one-liner).

    Each variable's element count is ``np.prod`` of its shape list. For a
    scalar variable the shape list is empty and ``np.prod([])`` is the
    float ``1.0``, which would make the summed total a float (printed as
    e.g. ``1234.0``). Each product is therefore cast to ``int`` before
    summing, so the result is always an integer (``0`` when there are no
    trainable variables).

    Returns:
        int: the total parameter count (also printed).
    """
    total_parameters = sum(
        int(np.prod(v.get_shape().as_list())) for v in tf.trainable_variables()
    )
    print(total_parameters)
    return total_parameters
def get_nb_params_shape(shape):
    """Return the number of elements described by ``shape``.

    Works for any rank: e.g. ``[D, F]`` gives ``D*F`` and ``[W, H, C]``
    gives ``W*H*C``. An empty shape (a scalar) gives 1.

    Args:
        shape: an iterable of dimension sizes. Each entry may be an int or an
            object exposing the size through a ``.value`` attribute (such as
            a TF1 ``Dimension``).

    Returns:
        int: the product of all dimension sizes.

    Raises:
        ValueError: if any dimension is unknown (``None``).
    """
    nb_params = 1
    for dim in shape:
        # TF1 Dimension objects carry their size in `.value`; plain ints don't.
        value = getattr(dim, "value", dim)
        if value is None:
            raise ValueError(
                "cannot count parameters of a shape with an unknown "
                "dimension: %r" % (shape,)
            )
        nb_params *= int(value)
    return nb_params
def count3():
    """Print and return the total trainable parameter count.

    The per-variable element count is delegated to ``get_nb_params_shape``.

    Returns:
        int: the total parameter count (also printed).
    """
    tot_nb_params = 0
    for trainable_variable in tf.trainable_variables():
        shape = trainable_variable.get_shape()  # e.g. [D, F] or [W, H, C]
        tot_nb_params += get_nb_params_shape(shape)
    # print() call form works on both Python 2 and 3 (the bare statement
    # form is a SyntaxError on Python 3).
    print(tot_nb_params)
    return tot_nb_params
def count4():
    """Print and return the total trainable parameter count using ``reduce``.

    ``reduce`` is imported from ``functools`` because it is not a builtin
    on Python 3. It is given an initial value of 1 so that a scalar
    variable (empty shape list) counts as one parameter. Without the
    initial value it would raise ``TypeError: reduce() of empty sequence
    with no initial value``.

    Returns:
        int: the total parameter count (also printed).
    """
    from functools import reduce
    from operator import mul

    def size(v):
        # Element count of variable v: product of its dimension sizes.
        return reduce(mul, v.get_shape().as_list(), 1)

    n = sum(size(v) for v in tf.trainable_variables())
    print(n)
    return n
def count5():
    """Print and return the total trainable parameter count.

    ``get_shape().as_list()`` is used rather than reading ``dim.value``
    because it yields plain ints under both TF1 (``Dimension`` objects)
    and TF2 (int dimensions) shape behavior.

    Returns:
        int: the total parameter count (also printed).

    Raises:
        TypeError: if a variable has an unknown (``None``) dimension.
    """
    total_parameters = 0
    # Iterate over all trainable variables.
    for variable in tf.trainable_variables():
        local_parameters = 1
        shape = variable.get_shape().as_list()  # dimension sizes as ints
        for i in shape:
            local_parameters *= i  # multiply the dimension sizes together
        total_parameters += local_parameters
    print(total_parameters)
    return total_parameters
def count6():
    """Print a labelled total trainable parameter count and return it.

    ``get_shape().as_list()`` is used rather than reading ``dim.value``
    because it yields plain ints under both TF1 (``Dimension`` objects)
    and TF2 (int dimensions) shape behavior.

    Returns:
        int: the total parameter count (also printed with a label).

    Raises:
        TypeError: if a variable has an unknown (``None``) dimension.
    """
    total_parameters = 0
    for variable in tf.trainable_variables():
        variable_parameters = 1
        for dim in variable.get_shape().as_list():
            variable_parameters *= dim
        total_parameters += variable_parameters
    print("Total number of trainable parameters: %d" % total_parameters)
    return total_parameters
def count7():
    """Print and return the total trainable parameter count via ``reduce(mul)``.

    Each variable's element count is ``reduce(mul, shape, 1)``. The
    initial value 1 makes a scalar (empty shape) count as one parameter.
    ``get_shape().as_list()`` is used instead of ``dim.value`` so the
    dimension sizes are plain ints under both TF1 and TF2 shape behavior.

    Returns:
        int: the total parameter count (also printed).
    """
    from functools import reduce
    from operator import mul

    num_params = 0
    for variable in tf.trainable_variables():
        shape = variable.get_shape().as_list()
        num_params += reduce(mul, shape, 1)
    # print() call form works on both Python 2 and 3.
    print(num_params)
    return num_params
参考资料：
1. How to count total number of trainable parameters in a tensorflow model?
2. What is the best way to count the total number of parameters in a model in TensorFlow?
3. Number of CNN learnable parameters - Python / TensorFlow
4. tensorflow 获取模型所有参数总和数量