测试代码及效果:
if __name__ == '__main__':
from torchvision import models
net = models.vgg16(pretrained=True)
summary(net, 3, 224, 224,
image='part3/valid_demo/valid_pic/dog1.jpg', # 传入一张图片作为输入
save_fig=True # 保存提取特征过程的输出图像
)
代码如下:
import os
import torch
from collections import Iterable
from PIL import Image
from torchvision.transforms import ToTensor, Resize
from torchvision.utils import save_image
# 使用tabulate包以表格显示最终结果
from tabulate import tabulate
def summary(model, *CHW, image=None