使用 torchsummary 查看 PyTorch 模型结构
要在 Jupyter Notebook 中查看一个神经网络的结构,可以使用 `torchsummary` 库中的 `summary` 函数。该函数以文本表格的形式列出模型的每一层,以及每层的输出形状和参数数量(它打印的是文字摘要,而不是图形化的结构图)。首先,确保你已经安装了 `torchsummary`(在 Notebook 单元格中也可以运行 `%pip install torchsummary`):
pip install torchsummary
然后,在 Jupyter Notebook 中运行以下代码,即可打印出模型的结构摘要:
from torchsummary import summary
import torch
import torch.nn as nn
class FeatureExtractor_01(nn.Module):
    """1-D CNN that maps an (N, C, L) input to an (N, 512) feature vector.

    The network has five stages of Conv1d -> BatchNorm1d -> ReLU with output
    widths 32, 64, 128, 256 and 512. Stages 1-4 end with a MaxPool1d. Stage 5
    ends with AdaptiveMaxPool1d(1) + Flatten, so the output size does not
    depend on the input length L.

    Args:
        in_channel: Number of input channels C of the first convolution.
        kernel_size: Kernel size shared by all five Conv1d layers.
        stride: Stride shared by all five Conv1d layers.
        padding: Zero-padding shared by all five Conv1d layers.
        mp_kernel_size: Kernel size of the MaxPool1d in stages 1-4.
        mp_stride: Stride of the MaxPool1d in stages 1-4.
    """

    def __init__(self, in_channel=5, kernel_size=3, stride=1, padding=2,
                 mp_kernel_size=2, mp_stride=2):
        super().__init__()

        def conv_bn_relu(c_in, c_out):
            # Conv -> BN -> ReLU. These stay at indices 0, 1, 2 inside each
            # stage, as before, so state_dict keys such as "fs.0.0.weight"
            # still match previously saved checkpoints.
            return [
                nn.Conv1d(c_in, c_out, kernel_size=kernel_size,
                          stride=stride, padding=padding),
                nn.BatchNorm1d(c_out),
                nn.ReLU(inplace=True),
            ]

        # The first conv takes `in_channel` inputs, so the constructor
        # argument is honoured instead of being hard-wired to 5.
        widths = [in_channel, 32, 64, 128, 256, 512]

        stages = []
        # Stages 1-4: conv block followed by a down-sampling max-pool.
        for c_in, c_out in zip(widths[:4], widths[1:5]):
            stages.append(nn.Sequential(
                *conv_bn_relu(c_in, c_out),
                nn.MaxPool1d(kernel_size=mp_kernel_size, stride=mp_stride),
            ))
        # Stage 5: global max-pool over the length axis, then flatten
        # (N, 512, 1) -> (N, 512).
        stages.append(nn.Sequential(
            *conv_bn_relu(widths[4], widths[5]),
            nn.AdaptiveMaxPool1d(1),
            nn.Flatten(),
        ))
        self.fs = nn.Sequential(*stages)

    def forward(self, tar, x=None, y=None):
        """Return the (N, 512) features computed from ``tar``.

        ``x`` and ``y`` are accepted but unused. Presumably they are kept so
        this extractor is call-compatible with sibling models — TODO confirm
        with callers.
        """
        return self.fs(tar)
# Create an instance of the model (its parameters live on the CPU).
model = FeatureExtractor_01()
# Print a per-layer summary for input samples of 5 channels x 100 steps.
# torchsummary's `summary` defaults to device="cuda". On a machine with a GPU
# it would build the dummy input on CUDA while `model` is still on the CPU and
# fail with a device mismatch. Request the CPU explicitly, or move the model
# with `.to("cuda")` and pass device="cuda" instead.
summary(model, (5, 100), device="cpu")
这将在输出中显示类似于以下内容的模型结构信息(输入形状为 `(5, 100)`,表中的 `-1` 表示批大小维度;实际输出末尾还会附带输入大小、前向/反向传播内存和参数内存等估计值):
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv1d-1              [-1, 32, 102]             512
       BatchNorm1d-2              [-1, 32, 102]              64
              ReLU-3              [-1, 32, 102]               0
         MaxPool1d-4               [-1, 32, 51]               0
            Conv1d-5               [-1, 64, 53]           6,208
       BatchNorm1d-6               [-1, 64, 53]             128
              ReLU-7               [-1, 64, 53]               0
         MaxPool1d-8               [-1, 64, 26]               0
            Conv1d-9              [-1, 128, 28]          24,704
      BatchNorm1d-10              [-1, 128, 28]             256
             ReLU-11              [-1, 128, 28]               0
        MaxPool1d-12              [-1, 128, 14]               0
           Conv1d-13              [-1, 256, 16]          98,560
      BatchNorm1d-14              [-1, 256, 16]             512
             ReLU-15              [-1, 256, 16]               0
        MaxPool1d-16               [-1, 256, 8]               0
           Conv1d-17              [-1, 512, 10]         393,728
      BatchNorm1d-18              [-1, 512, 10]           1,024
             ReLU-19              [-1, 512, 10]               0
AdaptiveMaxPool1d-20               [-1, 512, 1]               0
          Flatten-21                  [-1, 512]               0
================================================================
Total params: 525,696
Trainable params: 525,696
Non-trainable params: 0
----------------------------------------------------------------