转载自:https://blog.csdn.net/LXX516/article/details/80132228
定义一个特征提取的类:
#中间特征提取
class FeatureExtractor(nn.Module):
def __init__(self, submodule, extracted_layers):
super(FeatureExtractor,self).__init__()
self.submodule = submodule
self.extracted_layers= extracted_layers
def forward(self, x):
outputs = []
for name, module in self.submodule._modules.items():
if name is "fc": x = x.view(x.size(0), -1)
x = module(x)
print(name)
if name in self.extracted_layers:
outputs.append(x)
return outputs
#特征输出
myresnet=resnet18(pretrained=False)
myresnet.load_state_dict(torch.load('cafir_resnet18_1.pkl'))
exact_list=["conv1","layer1","avgpool"]
myexactor=FeatureExtractor(myresnet,exact_list)
x=myexactor(img)
在这里主要应用的是:
for nama, module in model._modules.items():
所以要根据自己的情况重写这个类,这个类提供个一个很不错的想法