一、背景
需要提取网络中间层的特征,用于特征工程或者可视化
二、解决方案
先说好,有很多解决的方法呢,这里给出一种我认为是简单的,官方提供的功能
核心代码如下
from torchvision.models.feature_extraction import create_feature_extractor
# Feature extraction with resnet
model = torchvision.models.resnet18()
# extract layer1 and layer3, giving as names `feat1` and feat2`
model = create_feature_extractor(
model, {'layer1': 'feat1', 'layer3': 'feat2'})
out = model(torch.rand(1, 3, 224, 224))
print([(k, v.shape) for k, v in out.items()])
# [('feat1', torch.Size([1, 64, 56, 56])),
# ('feat2', torch.Size([1, 256, 14, 14]))]