Tensorflow从模型文件中统计可以训练的参数的数目

#coding=utf-8
from tensorflow.python import pywrap_tensorflow

import os
import tensorflow as tf




flags = tf.app.flags
flags.DEFINE_string('model_path', "20200807195150/", "the export model path")


FLAGS = flags.FLAGS

ckpt = tf.train.get_checkpoint_state(FLAGS.model_path)
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')   # 载入图结构,保存在.meta文件中


total_parameters = 0
#iterating over all variables
for variable in tf.all_variables(): #统计全部的参数
#for variable in tf.trainable_variables():#统计可以训练的参数
  local_parameters=1
  shape = variable.get_shape()  #getting shape of a variable
  for i in shape:
    local_parameters*=i.value  #mutiplying dimension values
    total_parameters+=local_parameters
print(total_parameters)
©️2020 CSDN 皮肤主题: 大白 设计师:CSDN官方博客 返回首页