注意
目前(2019年一月)因为torchvision提供的VGG网络没有训练完全1,不建议使用torchvision提供的预训练模型来进行特征提取,建议先使用别的框架(例如TensorFlow或者Caffe之类的框架)提供的预训练过的模型来进行特征提取。
前言
这里的提取图片特征特指从VGG网络的最后一个conv层进行提取。虽然下面代码里面给出的是VGG16作为例子,其实也可以用其他的已经经过训练了的神经网络,包括自己训练的。
相关代码
模型结构
首先说下加载模型,这里用的是torch官方提供的已经训练好的模型,只需要从torchvision模块导入:
import torchvision.models as models
model = models.vgg16(pretrained=True)
上面的pretrained=True
是指使用预训练的权重,可以自己另外加载,但是这里就直接用官方提供的了。在第一次运行的时候会自动下载相应的模型(例如这里就是vgg16),如果弹出了类似“time out”之类的错误的话请运行多一次试试看。通常运行多几次就可以成功将模型下载下来。
然后需要确定的就是模型的结构,只需要:
feature = torch.nn.Sequential(*list(model.children())[:])
print(feature)
例如vgg16的输出是:
Sequential(
(0): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU