Pytorch:利用预训练好的VGG16网络提取图片特征

注意

目前(2019年一月)因为torchvision提供的VGG网络没有训练完全1,不建议使用torchvision提供的预训练模型来进行特征提取,建议先使用别的框架(例如TensorFlow或者Caffe之类的框架)提供的预训练过的模型来进行特征提取。

前言

这里的提取图片特征特指从VGG网络的最后一个conv层进行提取。虽然下面代码里面给出的是VGG16作为例子,其实也可以用其他的已经经过训练了的神经网络,包括自己训练的。

相关代码

模型结构

首先说下加载模型,这里用的是torch官方提供的已经训练好的模型,只需要从torchvision模块导入:

import torchvision.models as models

model = models.vgg16(pretrained=True)

上面的pretrained=True是指使用预训练的权重,可以自己另外加载,但是这里就直接用官方提供的了。在第一次运行的时候会自动下载相应的模型(例如这里就是vgg16),如果弹出了类似“time out”之类的错误的话请运行多一次试试看。通常运行多几次就可以成功将模型下载下来。

然后需要确定的就是模型的结构,只需要:

feature = torch.nn.Sequential(*list(model.children())[:])
print(feature)

例如vgg16的输出是:

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU
  • 65
    点赞
  • 339
    收藏
    觉得还不错? 一键收藏
  • 104
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 104
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值