优雅打印PyTorch模型的总参数量、每层名称、尺寸与精度

本文介绍了一个用于PyTorch模型的函数,它能打印模型的总参数量、可训练参数量、非训练参数量,以及每一层的名称、尺寸、精度、参数数量和训练状态。以ResNet50为例展示了如何使用该函数获取模型详细信息。
摘要由CSDN通过智能技术生成

在使用PyTorch进行深度学习训练时,经常需要打印模型结构,查看每一层的名称、尺寸、精度、参数量等等,并在最后计算总参数量。

以下的打印函数包含了打印:

  1. 模型总参数量
  2. 模型被冻结的参数量
  3. 模型未被总结的参数量
  4. 每层的名称
  5. 每层的类别名
  6. 每层的shape
  7. 每层的精度
  8. 每层的参数量
  9. 每层是否被冻结

打印函数

def get_pytorch_model_info(model: torch.nn.Module) -> (dict, list):
    """
    输入一个PyTorch Model对象,返回模型的总参数量(格式化为易读格式)以及每一层的名称、尺寸、精度、参数量、是否可训练和层的类别。

    :param model: PyTorch Model
    :return: (总参数量信息, 参数列表[包括每层的名称、尺寸、数据类型、参数量、是否可训练和层的类别])
    """
    params_list = []
    total_params = 0
    total_params_non_trainable = 0

    for name, param in model.named_parameters():
        # 获取参数所属层的名称
        layer_name = name.split('.')[0]
        # 获取层的对象
        layer = dict(model.named_modules())[layer_name]
        # 获取层的类名
        layer_class = layer.__class__.__name__

        params_count = param.numel()
        trainable = param.requires_grad
        params_list.append({
            'tensor': name,
            'layer_class': layer_class,
            'shape': str(list(param.size())),
            'precision': str(param.dtype).split('.')[-1],
            'params_count': str(params_count),
            'trainable': str(trainable),
        })
        total_params += params_count
        if not trainable:
            total_params_non_trainable += params_count

    total_params_trainable = total_params - total_params_non_trainable

    total_params_info = {
        'total_params': format_size(total_params),
        'total_params_trainable': format_size(total_params_trainable),
        'total_params_non_trainable': format_size(total_params_non_trainable)
    }

    return total_params_info, params_list

测试样例

我们用resnet50来进行测试:

import torchvision
model = torchvision.models.resnet50(weights=None)

完整测试代码:

import torch
import torchvision


def format_size(size):
    # 对总参数量做格式优化
    K, M, B = 1e3, 1e6, 1e9
    if size == 0:
        return '0'
    elif size < M:
        return f"{size / K:.1f}K"
    elif size < B:
        return f"{size / M:.1f}M"
    else:
        return f"{size / B:.1f}B"


def get_pytorch_model_info(model: torch.nn.Module) -> (dict, list):
    """
    输入一个PyTorch Model对象,返回模型的总参数量(格式化为易读格式)以及每一层的名称、尺寸、精度、参数量、是否可训练和层的类别。

    :param model: PyTorch Model
    :return: (总参数量信息, 参数列表[包括每层的名称、尺寸、数据类型、参数量、是否可训练和层的类别])
    """
    params_list = []
    total_params = 0
    total_params_non_trainable = 0

    for name, param in model.named_parameters():
        # 获取参数所属层的名称
        layer_name = name.split('.')[0]
        # 获取层的对象
        layer = dict(model.named_modules())[layer_name]
        # 获取层的类名
        layer_class = layer.__class__.__name__

        params_count = param.numel()
        trainable = param.requires_grad
        params_list.append({
            'tensor': name,
            'layer_class': layer_class,
            'shape': str(list(param.size())),
            'precision': str(param.dtype).split('.')[-1],
            'params_count': str(params_count),
            'trainable': str(trainable),
        })
        total_params += params_count
        if not trainable:
            total_params_non_trainable += params_count

    total_params_trainable = total_params - total_params_non_trainable

    total_params_info = {
        'total_params': format_size(total_params),
        'total_params_trainable': format_size(total_params_trainable),
        'total_params_non_trainable': format_size(total_params_non_trainable)
    }

    return total_params_info, params_list

运行结果:
在这里插入图片描述

  • 13
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值