Torchsummany打印数据提取

写在前面:Torchsummary的尿性

Torchsummany所提供的summary方法是作pytorch模型各层数据可视化的优秀方法,但是原本的summary方法只能做到调用后打印各层的数据信息,并没有提供任何子函数和官方方法来使使用者提取想要的数据信息。我们只能另辟蹊径。

Torchsummary提取打印数据方式

我们知道,python编程在调用第三方库函数时,可以选择查看该函数的具体构成。对于使用Pycharm集成开发环境的开发者而言,可以直接在引用位置长按ctrl键后点击对应函数名即可查看summary函数具体信息。
在这里插入图片描述
而如果未使用集成开发环境,又或使用其他环境,也可以直接到Torchsummary库保存环境下找到torchsummary.py文件进行查看。如本机所在位置。
在这里插入图片描述
查看后我们可以看到,summary方法的代码模块是比较少的,同时该模块所有的功能实现不依赖特定环境,也就是说我们可以自行复现该函数的功能,同时在复习的功能文件中加入数据返回值,以下是复习并修改后的summany方法代码。

import torch
import torch.nn as nn
from torch.autograd import Variable
from collections import OrderedDict
import numpy as np

def summary(model, input_size, batch_size=-1, device="cuda"):
    # 具体数据存储字典
    return_dic = {}
    def register_hook(module):
        def hook(module, input, output):
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary)

            m_key = "%s-%i" % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            summary[m_key]["input_shape"] = list(input[0].size())
            summary[m_key]["input_shape"][0] = batch_size
            if isinstance(output, (list, tuple)):
                summary[m_key]["output_shape"] = [
                    [-1] + list(o.size())[1:] for o in output
                ]
            else:
                summary[m_key]["output_shape"] = list(output.size())
                summary[m_key]["output_shape"][0] = batch_size

            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += torch.prod(torch.LongTensor(list(module.weight.size())))
                summary[m_key]["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += torch.prod(torch.LongTensor(list(module.bias.size())))
            summary[m_key]["nb_params"] = params

        if (
            not isinstance(module, nn.Sequential)
            and not isinstance(module, nn.ModuleList)
            and not (module == model)
        ):
            hooks.append(module.register_forward_hook(hook))

    device = device.lower()
    assert device in [
        "cuda",
        "cpu",
    ], "Input device is not valid, please specify 'cuda' or 'cpu'"

    if device == "cuda" and torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    # multiple inputs to the network
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # batch_size of 2 for batchnorm
    x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]
    # print(type(x[0]))

    # create properties
    summary = OrderedDict()
    hooks = []

    # register hook
    model.apply(register_hook)

    # make a forward pass
    # print(x.shape)
    model(*x)

    # remove these hooks
    for h in hooks:
        h.remove()
    layer_dic = {}
    print("----------------------------------------------------------------")
    line_new = "{:>20}  {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
    print(line_new)
    print("================================================================")
    total_params = 0
    total_output = 0
    trainable_params = 0
    layer_dic={}
    for layer in summary:
        layer_dic[layer] = [summary[layer]["output_shape"],int(summary[layer]["nb_params"])]
        # input_shape, output_shape, trainable, nb_params
        line_new = "{:>20}  {:>25} {:>15}".format(
            layer,
            str(summary[layer]["output_shape"]),
            "{0:,}".format(summary[layer]["nb_params"]),
        )
        total_params += summary[layer]["nb_params"]
        total_output += np.prod(summary[layer]["output_shape"])
        if "trainable" in summary[layer]:
            if summary[layer]["trainable"] == True:
                trainable_params += summary[layer]["nb_params"]
        print(line_new)
    return_dic['layer'] = layer_dic
    # assume 4 bytes/number (float on cuda).
    total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    print("================================================================")
    print("Total params: {0:,}".format(total_params))
    return_dic['Total params'] = total_params
    print("Trainable params: {0:,}".format(trainable_params))
    return_dic['Trainable params'] = trainable_params
    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
    return_dic['Non-trainable params'] = total_params - trainable_params
    print("----------------------------------------------------------------")
    print("Input size (MB): %0.2f" % total_input_size)
    return_dic['Input size'] = total_input_size
    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
    return_dic['Forward/backward pass size (MB)'] = total_output_size
    print("Params size (MB): %0.2f" % total_params_size)
    return_dic['Params size (MB)'] = total_params_size
    print("Estimated Total Size (MB): %0.2f" % total_size)
    return_dic['Estimated Total Size (MB)'] = total_size
    print("----------------------------------------------------------------")
    return return_dic

该方法返回一个包含原打印表格中所有数据元素的数据字典,使用者可以根据字典语法从中获取想要的数据。
该代码块可以直接在torchsummany.py中更新,覆盖掉原先的代码块,但并不建议这么做(因为你不能确定以后会不会用到只作打印的情况)。读者可以尝试在自己的项目中加入一个新的功能文件,用以存储该代码块。而后在需要使用到打印数据提取功能的项目文件中引入。如我自己所建的就是myselfdata_scan.py,可以被同级文件所调用。
在这里插入图片描述

希望本篇文章可以帮助到你。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值