我们知道,Keras有一个非常有好的功能是summary,可以打印显示网络结构和参数,一目了然。但是,Pytorch本身好像不支持这一点。不过,幸好有一个工具叫torchsummary,可以实现和Keras几乎一样的效果。
pip install torchsummary
然后我们定义好网络结构之后,就可以用summary来打印显示了。假设我们定义的网络结构是一个叫Generator的类。
import torch
from torchsummary import summary
# 需要使用device来指定网络在GPU还是CPU运行
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
netG_A2B = Generator(3, 3).to(device)
summary(netG_A2B, input_size=(3, 256, 256))
之后,就可以打印网络结构了。一个示例结构如下:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 224, 224] 1,792
ReLU-2 [-1,