201203-pytorch中model模型显示及维度信息(直接打印、summary、make_dot)

引言

在学习Pytorch中,为了更好理解网络结构,需要结合mdoel的图片结构和维度信息才能更好理解。keras中model.summary和plot_model工具就十分好用。在pytorch中,经过多方搜索,下列三种方式有助于自己理解,在此mark一下。其中summary要能知道模型的输入shape,可根据源代码和报错中提示进行尝试。

import torch
from torchviz import make_dot
from torch.autograd import Variable
from torchsummary import summary

model.netG  # 直接打印
summary(model.netG, (3,256,256))  # 每层输出shape
xtmp = Variable(torch.randn(1,3,256,256))
ytmp = model.netG(xtmp)
make_dot(ytmp, params=dict(model.netG.named_parameters())).render('tmp', view=True)  # render用于保存为图片

model.netD
summary(model.netD, (6, 256, 256))
xtmp2 = Variable(torch.randn(1,6,256,256))
ytmp2 = model.netD(xtmp2)
make_dot(ytmp2, params=dict(model.netD.named_parameters())).render('tmp2', view=True)

类似的make_dot,似乎更简洁些, From言有三

import torch
from torch.autograd import Variable
from visualize import  make_dot
x = Variable(torch.randn(1,3,48,48))
model = simpleconv3()
y = model(x)
g = make_dot(y)
g.view()

20200421更新:
HiddenLayer,torchwatch

20201203更新:
https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/22
https://pypi.org/project/pytorch-model-summary/

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: 
            continue
        param = parameter.numel()
        table.add_row([name, param])
        total_params+=param
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

summary(model, *inputs, batch_size=-1, show_input=False, show_hierarchical=False,
        print_summary=False, max_depth=1, show_parent_layers=False):
        
# summary的例子
import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch_model_summary import summary


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


# # show input shape
# print(summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=True))

# # show output shape
print(summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=False))

# show output shape and hierarchical view of net
print(summary(Net(), torch.zeros((1, 1, 28, 28)), show_input=True, show_hierarchical=True))
PSPNet(Pyramid Scene Parsing Network)是一种用于图像分割的深度学习模型,它在场景解析任务表现突出,能够处理不同尺度的信息PyTorch是一个广泛使用的深度学习框架,它提供了构建和训练复杂神经网络的工具。将PyTorch模型转换为ONNX(Open Neural Network Exchange)格式是一种常见做法,旨在实现模型在不同深度学习框架之间的兼容性,从而能够在不同的推理引擎上部署和执行。 要使用`pspnet-pytorch-master`模型进行ONNX推理,你需要遵循以下步骤: 1. **模型准备**:确保你已经安装了PyTorch,并且已经获取了`pspnet-pytorch-master`模型的代码和预训练权重。这通常涉及到克隆GitHub仓库并安装所需的依赖项。 2. **模型转换**:使用PyTorch的ONNX导出功能,将模型转换为ONNX格式。这需要在PyTorch运行模型并捕获模型的输出,以生成ONNX模型文件。 3. **验证转换**:在转换模型后,你应该验证转换后的ONNX模型是否能够正确地执行推理,与原PyTorch模型的输出保持一致。 4. **ONNX推理**:一旦确认ONNX模型无误,就可以使用支持ONNX的推理引擎(如ONNX Runtime, TensorRT等)来进行高效的推理。 以下是一个简化的代码示例,展示了如何将PyTorch模型导出为ONNX格式: ```python import torch from pspnet import PSPNet # 假设你已经导入了PSPNet类 # 加载预训练的PSPNet模型 model = PSPNet() # 假设这里已经加载了预训练权重 model.eval() # 设置为评估模式 # 创建一个dummy输入,用于模型的前向传播,以便生成ONNX模型 dummy_input = torch.randn(1, 3, 475, 475) # 这里的维度可能需要根据实际模型调整 # 导出模型到ONNX torch.onnx.export(model, dummy_input, "pspnet.onnx", verbose=True, input_names=['input'], output_names=['output']) # 使用ONNX模型进行推理 import onnx import onnxruntime # 加载ONNX模型 onnx_model = onnx.load("pspnet.onnx") onnx.checker.check_model(onnx_model) # 使用ONNX Runtime进行推理 ort_session = onnxruntime.InferenceSession("pspnet.onnx") ort_inputs = {ort_session.get_inputs()[0].name: dummy_input.numpy()} ort_outs = ort_session.run(None, ort_inputs) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值