获得tensorflow模型的参数量

tensorflow模型训练好之后,通常会保存为.ckpt文件,有时我们想了解一下模型保存了多少参数,这固然可以手动计算,但是速度太慢,这里我写了一个程序可以直接获得模型的参数量,下面是源代码:
转自: https://blog.csdn.net/chnguoshiwushuang/article/details/81588794

from tensorflow.python import pywrap_tensorflow
import os
import numpy as np
model_dir = "models_pretrained/"
checkpoint_path = os.path.join(model_dir, "model.ckpt-82798")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
total_parameters = 0
for key in var_to_shape_map:#list the keys of the model
    # print(key)
    # print(reader.get_tensor(key))
    shape = np.shape(reader.get_tensor(key))  #get the shape of the tensor in the model
    shape = list(shape)
    # print(shape)
    # print(len(shape))
    variable_parameters = 1
    for dim in shape:
        # print(dim)
        variable_parameters *= dim
    # print(variable_parameters)
    total_parameters += variable_parameters

print(total_parameters)
 
 
    • 2
      点赞
    • 2
      收藏
      觉得还不错? 一键收藏
    • 0
      评论

    “相关推荐”对你有帮助么?

    • 非常没帮助
    • 没帮助
    • 一般
    • 有帮助
    • 非常有帮助
    提交
    评论
    添加红包

    请填写红包祝福语或标题

    红包个数最小为10个

    红包金额最低5元

    当前余额3.43前往充值 >
    需支付:10.00
    成就一亿技术人!
    领取后你会自动成为博主和红包主的粉丝 规则
    hope_wisdom
    发出的红包
    实付
    使用余额支付
    点击重新获取
    扫码支付
    钱包余额 0

    抵扣说明:

    1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
    2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

    余额充值