pytorch 获取模型参数_Pytorch获取模型参数情况的方法

分享人工智能技术干货,专注深度学习与计算机视觉领域!

相较于Tensorflow,Pytorch一开始就是以动态图构建神经网络图的,其获取模型参数的方法也比较容易,既可以根据其内建接口自己写代码获取模型参数情况,也可以借助第三方库来获取模型参数情况,下面,就让我们一起来了解Pytorch获取模型参数情况的这两种方法!

Pytorch依据其内建接口自己写代码获取模型参数情况,我们主要是借助该框架提供的模型parameters()接口并获取对应参数的size来实现的,对于该参数是否属于可训练参数,那么我们可以依据Pytorch提供的requires_grad标志位来进行判断,具体方法如下代码所示:

# 定义总参数量、可训练参数量及非可训练参数量变量

Total_params = 0

Trainable_params = 0

NonTrainable_params = 0

# 遍历model.parameters()返回的全局参数列表

for param in model.parameters():

mulValue = np.prod(param.size()) # 使用numpy prod接口计算参数数组所有元素之积

Total_params += mulValue # 总参数量

if param.requires_grad:

Trainable_params += mulValue # 可训练参数量

else:

NonTrainable_params += mulValue # 非可训练参数量

print(f'Total p

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值