pytorch 提取中间层的特征

一、背景

需要提取网络中间层的特征,用于特征工程或者可视化

二、解决方案

先说好,有很多解决的方法呢,这里给出一种我认为是简单的,官方提供的功能

https://pytorch.org/vision/main/generated/torchvision.models.feature_extraction.create_feature_extractor.html#torchvision.models.feature_extraction.create_feature_extractor

核心代码如下

from torchvision.models.feature_extraction import create_feature_extractor

 # Feature extraction with resnet
model = torchvision.models.resnet18()
# extract layer1 and layer3, giving as names `feat1` and feat2`
model = create_feature_extractor(
	model, {'layer1': 'feat1', 'layer3': 'feat2'})
out = model(torch.rand(1, 3, 224, 224))
print([(k, v.shape) for k, v in out.items()])
#     [('feat1', torch.Size([1, 64, 56, 56])),
#     ('feat2', torch.Size([1, 256, 14, 14]))]

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值