模型可视化库torch-summary使用

安装

pip下安装:

pip install torch-summary

conda下安装:

conda install torch-summary

 注:torchsummary与torch-summary是两个不同库!后者是前者的升级版,添加更多功能且解决了部分bug,因此推荐使用torch-summary!

使用

通过nn.Module构建一个模型(以一个简单的LSTM为例):

import torch.nn as nn
class LSTM(nn.Module):
    def __init__(self):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=16, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Linear(8, 1)
        )

    def forward(self, x):
        x = self.lstm(x)[0][:, -1, :]
        res = self.mlp(x)
        return res

if __name__ == '__main__':
    model = LSTM()
    input = torch.randn(8, 32, 1)
    output = model(input)

使用torch-summary可视化网络: 

import torchsummary

model = LSTM()
torchsummary.summary(model, input_size=(32, 1), batch_size=8)

结果输出:

        图中可看出模型的层次结构以及各层的参数统计:包括LSTM和Sequential层,分别包含1216、145个参数,其中Sequential层的两个Linear层分别包含136、9个参数。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值