TensorFlow打印网络参数的个数

注意区分打印网络参数的个数和打印网络参数(权重和偏置)的个数

在TensorFlow 1.0 中,可以通过使用tf.trainable_variables()获取模型的所有可训练参数(即权重和偏置),并使用sess.run()在会话中运行这些变量来打印它们的值。

打印网络参数(权重和偏置)

import tensorflow as tf

# 构建模型

# 创建会话
with tf.Session() as sess:
    # 初始化所有变量
    sess.run(tf.global_variables_initializer())
    
    # 获取所有可训练的变量
    trainable_vars = tf.trainable_variables()
    
    # 打印每个变量的名称和值
    for var in trainable_vars:
        print(var.name)
        print(sess.run(var))

打印出网络参数的个数,需要获取每个可训练参数的形状,然后计算它们的乘积来得到每个参数的元素个数。最后,将所有参数的元素个数相加即可得到网络参数的总个数。

import tensorflow as tf
import numpy as np

# 构建模型

# 创建会话
with tf.Session() as sess:
    # 初始化所有变量
    sess.run(tf.global_variables_initializer())
    
    # 获取所有可训练的变量
    trainable_vars = tf.trainable_variables()
    
    # 计算所有参数的总个数
    total_parameters = 0
    for variable in trainable_vars:
        # 获取变量的形状,例如[5, 5, 1, 32]表示一个5x5的32通道卷积核
        shape = variable.get_shape()
        
        # 计算当前变量的参数个数,为形状的各维大小的乘积
        variable_parametes = 1
        for dim in shape:
            variable_parametes *= dim.value
        
        # 将当前变量的参数个数加到总个数上
        total_parameters += variable_parametes
    
    print("Total number of parameters in the network: {}".format(total_parameters))
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值