在Pytorch开源的网络以及权重的基础上进行特征提取
就用VGG16网络举个例子 官方开源的vgg网络
我们想提取全链接层的特征时,只需要将官方的代码注释掉一部分
def __init__(self, features, num_classes=1000, init_weights=True):
super(VGG, self).__init__()
self.features = features
self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, 4096),
# nn.ReLU(True),
# nn.Dropout(),
# nn.Linear(4096, num_classes),
)
if init_weights:
self._initialize_weights()
然后在读取网络权重的时候
def vgg16(pretrained=False, **kwargs):
"""VGG 16-layer model (configuration "D")
Args:
pretrained (bool): If True, returns a mode