【Pytorch API笔记3】用torch.numel()来统计网络的参数量

如何统计网络的大小,可以试一试torch.numel()函数
torch.numel()函数,可以计算出单个tensor元素的个数

一、对单个tensor使用,求tensor元素的个数

x = torch.randn((1, 3, 5, 7))
x.numel()
torch.numel()

输出105

二、求整个网络的参数

  n_p = sum(x.numel() for x in model.parameters())  # number parameters

如下示意图,可以计算网络的参数量
一个线性层,输入维度为1,输出维度为100
这个网络有200个参数,可以用x.numel() 巧妙计算出整个网络所需要的参数量

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear = torch.nn.Linear(1, 100) # 输入1、输出的维度都是100
    def forward(self, x):
        out = self.linear(x)
        return out
    
net = LinearModel()
n_p = sum(x.numel() for x in net.parameters())  # number parameters
print(n_p)  ##  ------>输出为200 

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值