PyTorch Summary 项目教程

PyTorch Summary 项目教程

pytorch-summarypytorch-summary - 一个PyTorch库,提供类似于Keras中model.summary()的功能,用于可视化模型结构和参数信息。项目地址:https://gitcode.com/gh_mirrors/py/pytorch-summary

1. 项目的目录结构及介绍

PyTorch Summary 项目的目录结构相对简单,主要包含以下文件和目录:

pytorch-summary/
├── LICENSE
├── README.md
├── setup.py
└── torchsummary/
    ├── __init__.py
    └── summary.py

文件和目录介绍

  • LICENSE: 项目许可证文件,采用 MIT 许可证。
  • README.md: 项目说明文档,包含项目的基本介绍、安装方法和使用示例。
  • setup.py: 用于安装项目的脚本文件。
  • torchsummary/: 项目的主要代码目录。
    • __init__.py: 初始化文件,使 torchsummary 成为一个 Python 包。
    • summary.py: 核心文件,包含实现模型摘要功能的代码。

2. 项目的启动文件介绍

项目的启动文件主要是 setup.py,它负责项目的安装和分发。以下是 setup.py 的基本内容:

from setuptools import setup, find_packages

setup(
    name='torchsummary',
    version='1.4.5',
    description='Model summary in PyTorch similar to `model.summary()` in Keras',
    url='https://github.com/sksq96/pytorch-summary',
    author='Sksq96',
    author_email='sksq96@gmail.com',
    license='MIT',
    packages=find_packages(),
    install_requires=[
        'torch',
    ],
    classifiers=[
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: 3.9',
    ],
)

启动文件功能介绍

  • setup.py 使用 setuptools 进行项目的打包和安装。
  • 定义了项目的名称、版本、描述、URL、作者、许可证等信息。
  • 指定了项目的依赖包,如 torch
  • 使用 find_packages() 自动查找项目中的包。

3. 项目的配置文件介绍

PyTorch Summary 项目没有显式的配置文件,其核心功能主要通过 torchsummary/summary.py 文件实现。以下是 summary.py 文件的部分代码:

import torch
import torch.nn as nn
import torch.nn.functional as F

def summary(model, input_size, batch_size=-1, device='cuda'):
    def register_hook(module):
        def hook(module, input, output):
            class_name = str(module.__class__).split('.')[-1].split("'")[0]
            module_idx = len(summary)

            m_key = f'{class_name}-{module_idx+1}'
            summary[m_key] = {}
            summary[m_key]['input_shape'] = list(input[0].size())
            summary[m_key]['input_shape'][0] = batch_size
            if isinstance(output, (list, tuple)):
                summary[m_key]['output_shape'] = [
                    [-1] + list(o.size())[1:] for o in output
                ]
            else:
                summary[m_key]['output_shape'] = list(output.size())
                summary[m_key]['output_shape'][0] = batch_size

            params = 0
            if hasattr(module, 'weight') and hasattr(module.weight, 'size'):
                params += torch.prod(torch.LongTensor(list(module.weight.size())))
                summary[m_key]['trainable'] = module.weight.requires_grad
            if hasattr(module, 'bias') and hasattr(module.bias, 'size'):
                params += torch.prod(torch.LongTensor(list(module.bias.size())))
            summary[m_key]['nb_params'] = params

        if (not isinstance(module, nn.Sequential) and
           not isinstance(module, nn.ModuleList) and
           not (module == model)):
            hooks.append(module.register_forward_hook

pytorch-summarypytorch-summary - 一个PyTorch库,提供类似于Keras中model.summary()的功能,用于可视化模型结构和参数信息。项目地址:https://gitcode.com/gh_mirrors/py/pytorch-summary

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

凤尚柏Louis

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值