PyTorch：计算模型的参数量

参考：https://blog.csdn.net/jdzwanghao/article/details/84196239

def model_structure(model, type_size=1):
    """Print a per-parameter table for ``model`` and return its parameter count.

    One row is printed for every entry of ``model.named_parameters()``,
    showing the parameter name, its shape and its element count. The rows are
    followed by the total element count and by that total scaled by
    ``type_size`` and divided by 1e6.

    Args:
        model: object exposing ``named_parameters()`` and ``_get_name()``
            (e.g. a ``torch.nn.Module``).
        type_size: multiplier applied to the total before dividing by 1e6 in
            the summary line. The default 1 reports millions of parameters;
            4 (bytes per float32 element) gives an approximate size in MB.

    Returns:
        int: total number of parameter elements.
    """
    key_width, shape_width, num_width = 30, 40, 10
    rule = '-' * 90

    print(rule)
    print('|' + ' ' * 11 + 'weight name' + ' ' * 10 + '|'
          + ' ' * 15 + 'weight shape' + ' ' * 15 + '|'
          + ' ' * 3 + 'number' + ' ' * 3 + '|')
    print(rule)

    num_para = 0
    for key, w_variable in model.named_parameters():
        # numel() equals the product of the shape dims (1 for a 0-dim tensor).
        each_para = w_variable.numel()
        num_para += each_para
        # ljust pads short cells to the column width and leaves over-long
        # ones untouched, so a long name widens its row instead of being cut.
        print('| {} | {} | {} |'.format(
            key.ljust(key_width),
            str(w_variable.shape).ljust(shape_width),
            str(each_para).ljust(num_width),
        ))
    print(rule)
    print('The total number of parameters: ' + str(num_para))
    # '{:f}' reproduces the original '{:4f}' output exactly: the 4 there was a
    # minimum width, which a six-decimal '%f' value always exceeds.
    print('The parameters of Model {}: {:f}M'.format(
        model._get_name(), num_para * type_size / 1000 / 1000))
    print(rule)
    return num_para

运行结果：

------------------------------------------------------------------------------------------
|           weight name          |               weight shape               |   number   |
------------------------------------------------------------------------------------------
| embed_in.0.weight              | torch.Size([28, 1, 1, 5, 5])             | 700        |
| embed_in.0.bias                | torch.Size([28])                         | 28         |
| downC.0.block1.0.weight        | torch.Size([28, 28, 1, 3, 3])            | 7056       |
| downC.0.block1.1.weight        | torch.Size([28])                         | 28         |
| downC.0.block1.1.bias          | torch.Size([28])                         | 28         |
| downC.0.block2.0.weight        | torch.Size([28, 28, 3, 3, 3])            | 21168      |
| downC.0.block2.1.weight        | torch.Size([28])                         | 28         |
| downC.0.block2.1.bias          | torch.Size([28])                         | 28         |
| downC.0.block2.3.weight        | torch.Size([28, 28, 3, 3, 3])            | 21168      |
| downC.0.block3.weight          | torch.Size([28])                         | 28         |
| downC.0.block3.bias            | torch.Size([28])                         | 28         |
| downC.1.block1.0.weight        | torch.Size([36, 28, 1, 3, 3])            | 9072       |
| downC.1.block1.1.weight        | torch.Size([36])                         | 36         |
| downC.1.block1.1.bias          | torch.Size([36])                         | 36         |
| downC.1.block2.0.weight        | torch.Size([36, 36, 3, 3, 3])            | 34992      |
| downC.1.block2.1.weight        | torch.Size([36])                         | 36         |
| downC.1.block2.1.bias          | torch.Size([36])                         | 36         |
| downC.1.block2.3.weight        | torch.Size([36, 36, 3, 3, 3])            | 34992      |
| downC.1.block3.weight          | torch.Size([36])                         | 36         |
| downC.1.block3.bias            | torch.Size([36])                         | 36         |
| downC.2.block1.0.weight        | torch.Size([48, 36, 1, 3, 3])            | 15552      |
| downC.2.block1.1.weight        | torch.Size([48])                         | 48         |
| downC.2.block1.1.bias          | torch.Size([48])                         | 48         |
| downC.2.block2.0.weight        | torch.Size([48, 48, 3, 3, 3])            | 62208      |
| downC.2.block2.1.weight        | torch.Size([48])                         | 48         |
| downC.2.block2.1.bias          | torch.Size([48])                         | 48         |
| downC.2.block2.3.weight        | torch.Size([48, 48, 3, 3, 3])            | 62208      |
| downC.2.block3.weight          | torch.Size([48])                         | 48         |
| downC.2.block3.bias            | torch.Size([48])                         | 48         |
| center.block1.0.weight         | torch.Size([64, 48, 1, 3, 3])            | 27648      |
| center.block1.1.weight         | torch.Size([64])                         | 64         |
| center.block1.1.bias           | torch.Size([64])                         | 64         |
| center.block2.0.weight         | torch.Size([64, 64, 3, 3, 3])            | 110592     |
| center.block2.1.weight         | torch.Size([64])                         | 64         |
| center.block2.1.bias           | torch.Size([64])                         | 64         |
| center.block2.3.weight         | torch.Size([64, 64, 3, 3, 3])            | 110592     |
| center.block3.weight           | torch.Size([64])                         | 64         |
| center.block3.bias             | torch.Size([64])                         | 64         |
| upS.0.1.weight                 | torch.Size([48, 64, 1, 1, 1])            | 3072       |
| upS.0.1.bias                   | torch.Size([48])                         | 48         |
| upS.1.1.weight                 | torch.Size([36, 48, 1, 1, 1])            | 1728       |
| upS.1.1.bias                   | torch.Size([36])                         | 36         |
| upS.2.1.weight                 | torch.Size([28, 36, 1, 1, 1])            | 1008       |
| upS.2.1.bias                   | torch.Size([28])                         | 28         |
| upC.0.block1.0.weight          | torch.Size([48, 48, 1, 3, 3])            | 20736      |
| upC.0.block1.1.weight          | torch.Size([48])                         | 48         |
| upC.0.block1.1.bias            | torch.Size([48])                         | 48         |
| upC.0.block2.0.weight          | torch.Size([48, 48, 3, 3, 3])            | 62208      |
| upC.0.block2.1.weight          | torch.Size([48])                         | 48         |
| upC.0.block2.1.bias            | torch.Size([48])                         | 48         |
| upC.0.block2.3.weight          | torch.Size([48, 48, 3, 3, 3])            | 62208      |
| upC.0.block3.weight            | torch.Size([48])                         | 48         |
| upC.0.block3.bias              | torch.Size([48])                         | 48         |
| upC.1.block1.0.weight          | torch.Size([36, 36, 1, 3, 3])            | 11664      |
| upC.1.block1.1.weight          | torch.Size([36])                         | 36         |
| upC.1.block1.1.bias            | torch.Size([36])                         | 36         |
| upC.1.block2.0.weight          | torch.Size([36, 36, 3, 3, 3])            | 34992      |
| upC.1.block2.1.weight          | torch.Size([36])                         | 36         |
| upC.1.block2.1.bias            | torch.Size([36])                         | 36         |
| upC.1.block2.3.weight          | torch.Size([36, 36, 3, 3, 3])            | 34992      |
| upC.1.block3.weight            | torch.Size([36])                         | 36         |
| upC.1.block3.bias              | torch.Size([36])                         | 36         |
| upC.2.block1.0.weight          | torch.Size([28, 28, 1, 3, 3])            | 7056       |
| upC.2.block1.1.weight          | torch.Size([28])                         | 28         |
| upC.2.block1.1.bias            | torch.Size([28])                         | 28         |
| upC.2.block2.0.weight          | torch.Size([28, 28, 3, 3, 3])            | 21168      |
| upC.2.block2.1.weight          | torch.Size([28])                         | 28         |
| upC.2.block2.1.bias            | torch.Size([28])                         | 28         |
| upC.2.block2.3.weight          | torch.Size([28, 28, 3, 3, 3])            | 21168      |
| upC.2.block3.weight            | torch.Size([28])                         | 28         |
| upC.2.block3.bias              | torch.Size([28])                         | 28         |
| embed_out.0.weight             | torch.Size([28, 28, 1, 5, 5])            | 19600      |
| embed_out.0.bias               | torch.Size([28])                         | 28         |
| out_affs_2.0.weight            | torch.Size([3, 28, 1, 1, 1])             | 84         |
| out_affs_2.0.bias              | torch.Size([3])                          | 3          |
------------------------------------------------------------------------------------------
The total number of parameters: 821531
The parameters of Model UNet_PNI: 0.821531M
------------------------------------------------------------------------------------------
  • 8
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

深山里的小白羊

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值