使用torchkeras打印Pytorch模型结构和基本参数信息

在使用Pytorch构建神经网络模型后,我们需要看一下自己写的模型的网络结构,此时可以使用torchkeras模块中的summary函数实现该功能。以多层感知机为例,首先我们构建网络结构并打印该模型的初步信息,代码如下:

import torch
from torch import nn
from torchkeras import summary


def create_net():
    net = nn.Sequential()
    net.add_module('linear1', nn.Linear(15, 20))
    net.add_module('relu1', nn.ReLU())
    net.add_module('linear2', nn.Linear(20, 1))
    net.add_module('sigmoid', nn.Sigmoid())
    return net

# 创建模型
net = create_net()
# 打印模型的基本信息
print(net)

效果如下:

Sequential(
  (linear1): Linear(in_features=15, out_features=20, bias=True)
  (relu1): ReLU()
  (linear3): Linear(in_features=20, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

可以看到这种方式输出的模型结构内容不够详细而且形式不够直观,我们想实现像keras打印模型结构的那种效果,可以使用torchkeras模型的summary函数。

torchkeras模型的安装方法为:

pip install torchkeras

调用代码如下:

import torch
from torch import nn
from torchkeras import summary


def create_net():
    net = nn.Sequential()
    net.add_module('linear1', nn.Linear(15, 20))
    net.add_module('relu1', nn.ReLU())
    net.add_module('linear2', nn.Linear(20, 1))
    net.add_module('sigmoid', nn.Sigmoid())
    return net

# 创建模型
net = create_net()
# 使用torchkeras中的summary函数打印模型结构和参数
print(summary(net, input_shape=(15, )))

效果如下:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                   [-1, 20]             320
              ReLU-2                   [-1, 20]               0
            Linear-3                    [-1, 1]              21
           Sigmoid-4                    [-1, 1]               0
================================================================
Total params: 341
Trainable params: 341
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.000057
Forward/backward pass size (MB): 0.000320
Params size (MB): 0.001301
Estimated Total Size (MB): 0.001678
----------------------------------------------------------------
None

可以看到这里模型的层次结构、输出形状和参数基本信息都可以被打印出来,这个模块还是很不错的。此外torchkeras的功能不仅限于此,它是在Pytorch上实现的仿Keras的高层Model结构,可以让我们像使用Keras一样训练Pytorch模型,关于该模块的细节可以参考以下文章:

torchkeras,像Keras一样训练Pytorch模型!

  • 3
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值