本文聚焦于一种常见需求:使用预训练的resnet50网络,但是只用来提取特征而不要分类结果。
ResNet的类定义为:
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000):
pass
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
其实我们不想修改很多,只是把forward函数改一改而已。
1. 重新定义一个网络
class MyResNet(ResNet):
def __init__(self, block, layers, num_classes=1000):
pass
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.lay