#coding=utf-8
from tensorflow.python import pywrap_tensorflow
import os
import tensorflow as tf
# Command-line configuration. `model_path` is the directory that
# tf.train.get_checkpoint_state() searches for the checkpoint to inspect.
flags = tf.app.flags
FLAGS = flags.FLAGS  # the global flag registry; flags defined below attach to it
flags.DEFINE_string(
    'model_path',
    "20200807195150/",
    "the export model path",
)
# Locate the latest checkpoint in the configured directory.
ckpt = tf.train.get_checkpoint_state(FLAGS.model_path)
# get_checkpoint_state() returns None when no checkpoint state file is found;
# fail with a readable message instead of an AttributeError on None.
if ckpt is None or not ckpt.model_checkpoint_path:
    raise IOError("No checkpoint found in %r" % (FLAGS.model_path,))
# Load the graph structure, which is saved in the checkpoint's .meta file.
# The returned Saver is not used further here; the call is made to populate
# the graph with the checkpoint's variables.
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
# Sum the number of scalar parameters over all variables in the imported graph.
total_parameters = 0
# Counts every global variable (tf.global_variables() replaces the deprecated
# tf.all_variables()). To count only trainable parameters, iterate over
# tf.trainable_variables() instead.
for variable in tf.global_variables():
    shape = variable.get_shape()
    # num_elements() is the product of all dimensions (1 for a scalar) and
    # None when the static shape is not fully known.
    local_parameters = shape.num_elements()
    if local_parameters is None:
        raise ValueError(
            "Variable %s has a not fully defined shape %s; cannot count its "
            "parameters" % (variable.name, shape))
    total_parameters += local_parameters
print(total_parameters)
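# TensorFlow: count the number of parameters stored in a model checkpoint file.
# (Original article title: "Counting the trainable parameters in a TensorFlow
# model file". Note that the script above counts all global variables, not
# only the trainable ones.)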