PyTorch Summary 项目教程
1. 项目的目录结构及介绍
PyTorch Summary 项目的目录结构相对简单,主要包含以下文件和目录:
pytorch-summary/
├── LICENSE
├── README.md
├── setup.py
└── torchsummary/
├── __init__.py
└── summary.py
文件和目录介绍
LICENSE
: 项目许可证文件,采用 MIT 许可证。README.md
: 项目说明文档,包含项目的基本介绍、安装方法和使用示例。setup.py
: 用于安装项目的脚本文件。torchsummary/
: 项目的主要代码目录。__init__.py
: 初始化文件,使torchsummary
成为一个 Python 包。summary.py
: 核心文件,包含实现模型摘要功能的代码。
2. 项目的启动文件介绍
项目的启动文件主要是 setup.py
,它负责项目的安装和分发。以下是 setup.py
的基本内容:
from setuptools import setup, find_packages
setup(
name='torchsummary',
version='1.4.5',
description='Model summary in PyTorch similar to `model.summary()` in Keras',
url='https://github.com/sksq96/pytorch-summary',
author='Sksq96',
author_email='sksq96@gmail.com',
license='MIT',
packages=find_packages(),
install_requires=[
'torch',
],
classifiers=[
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
],
)
启动文件功能介绍
setup.py
使用setuptools
进行项目的打包和安装。- 定义了项目的名称、版本、描述、URL、作者、许可证等信息。
- 指定了项目的依赖包,如
torch
。 - 使用
find_packages()
自动查找项目中的包。
3. 项目的配置文件介绍
PyTorch Summary 项目没有显式的配置文件,其核心功能主要通过 torchsummary/summary.py
文件实现。以下是 summary.py
文件的部分代码:
import torch
import torch.nn as nn
import torch.nn.functional as F
def summary(model, input_size, batch_size=-1, device='cuda'):
def register_hook(module):
def hook(module, input, output):
class_name = str(module.__class__).split('.')[-1].split("'")[0]
module_idx = len(summary)
m_key = f'{class_name}-{module_idx+1}'
summary[m_key] = {}
summary[m_key]['input_shape'] = list(input[0].size())
summary[m_key]['input_shape'][0] = batch_size
if isinstance(output, (list, tuple)):
summary[m_key]['output_shape'] = [
[-1] + list(o.size())[1:] for o in output
]
else:
summary[m_key]['output_shape'] = list(output.size())
summary[m_key]['output_shape'][0] = batch_size
params = 0
if hasattr(module, 'weight') and hasattr(module.weight, 'size'):
params += torch.prod(torch.LongTensor(list(module.weight.size())))
summary[m_key]['trainable'] = module.weight.requires_grad
if hasattr(module, 'bias') and hasattr(module.bias, 'size'):
params += torch.prod(torch.LongTensor(list(module.bias.size())))
summary[m_key]['nb_params'] = params
if (not isinstance(module, nn.Sequential) and
not isinstance(module, nn.ModuleList) and
not (module == model)):
hooks.append(module.register_forward_hook