pytorch的vgg19的预训练模型提取图片特征

概要

推荐中item侧的特征,item的头图、封面等是一个比较好的特征。但当物料比较少时不可能自己搭建模型训练提取图像特征。比较好的处理方法就是运用torchvision.models模块中的预训练好的模型进行特征提取。
最近,在相关的内容,并且在提取特征时遇到一个比较隐形的坑。记录一下,希望可以帮到遇到同样问题的小伙伴。

内容

torchvision.models集成了很多模型,大家选择适合自己业务场景的模型。本文选用的是vgg19。代码如下:

import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.autograd import Variable
import numpy as np
from PIL import Image
import pandas as pd

# pretrained参数表示是否运用预训练模型
vgg_model = models.vgg19(pretrained=True)
# 这里重新定义最后一层,默认层输出的维度是1000,我这里重新定义可以输出自己想要的维度。
new_classifier = torch.nn.Sequential(*list(vgg_model.children())[-1][:6])
# (6): Linear(in_features=4096, out_features=1000, bias=True)
new_classifier.add_module("Linear output", torch.nn.Linear(4096, 256))
vgg_model.classifier = new_classifier
trans = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
# 这里convert成三通道,如果本来就是三通道的图片,则可以省略。
im = Image.open(image_dir).convert('RGB')
im = trans(im)
im.unsqueeze_(dim=0)
# Set model to eval mode 这一步很重要
vgg_model = vgg_model.eval()
y = vgg_model(im).data.numpy().tolist()

以上代码就是利用vgg19提取图像特征所有代码。
其中比较要注意的地方:
1、图像转换成RGB三通道,有些图片因为格式等其他问题导致不符合输入要求,容易导致报错。
2、vgg_model = vgg_model.eval()这个代码很重要,因为没有这行代码也能运行成功,但是输出的结果不稳定,同一张图像两次执行的结果不一致。这是因为一些模型使用的模块有不用的训练和评估行为,比如批量正则。需要正确的使用model.train()或model.eval()。
在这里插入图片描述
参考:
1、 Using the pre-trained models
2、PyTorch中文文档

  • 2
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: PyTorch 提供了许多训练模型,如 AlexNet、VGG、ResNet、Inception 等,这些模型都在 ImageNet 数据集上进行了训练。我们可以利用这些训练模型提取图像特征,以便用于图像分类、目标检测等任务。 以下是一个示例代码,利用 ResNet-50 模型提取图像特征: ```python import torch import torchvision.models as models import torchvision.transforms as transforms # 加载训练模型 resnet = models.resnet50(pretrained=True) # 定义数据处理 transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # 加载图像 img = Image.open('test.jpg') # 处理图像 img_tensor = transform(img) # 增加一个维度,变成 4D 张量 img_tensor.unsqueeze_(0) # 特征提取 features = resnet(img_tensor) # 打印特征向量 print(features) ``` 其中,我们首先加载了 ResNet-50 模型,并定义了一个数据处理方法 `transform`,然后加载了一张测试图片,并将其转化为 PyTorch Tensor 格式,并增加了一个维度,变成 4D 张量。最后,我们通过调用 `resnet` 模型提取特征,得到一个 1x1000 的张量,我们可以将其用于图像分类等任务中。 ### 回答2: PyTorch是一个功能强大的机器学习库,其中包含许多用于训练模型特征提取工具。 训练模型是在大规模数据集上进行训练并保存的模型,可以用来处理各种任务。PyTorch提供了许多经过训练模型,如ResNet、Inception、VGG等,这些模型具有很强的特征提取能力。 使用PyTorch进行训练模型特征提取很简单。首先,我们需要下载和加载所需的训练模型PyTorch提供了一种方便的方式,可以直接从网上下载训练模型并加载到我们的程序中。 加载训练模型后,我们可以通过简单地将数据传递给该模型提取特征。这通常涉及将输入数据通过模型的前向传播过程,并从中获取感兴趣的特定层或层的输出。 例如,如果我们想要提取图像的特征,我们可以使用ResNet模型。我们可以将图片传递给该模型,然后从所需的层中获取输出。这些特征可以用来训练其他模型,进行图像分类、目标检测等任务。 PyTorch训练模型特征提取功能很受欢迎,因为它不需要从头开始训练模型,而是利用了已经学习到的知识。这样可以节省时间和计算资源。此外,训练模型通常在大规模数据集上进行了训练,因此其特征提取能力很强。 总而言之,PyTorch提供了简单且强大的训练模型特征提取工具,可以用于各种任务。通过加载训练模型提取特征,我们可以快速构建和训练其他模型,从而提高模型性能。 ### 回答3: PyTorch 提供了许多训练模型,它们通过在大规模数据集上进行训练,能够有效捕捉到图像或文本等数据的特征训练模型特征提取是指利用这些模型提取输入数据的特征表示。 在 PyTorch 中,我们可以使用 torchvision 包提供的训练模型。这些模型包括常用的卷积神经网络(如 ResNet、VGG)和循环神经网络(如 LSTM、GRU)等,它们在 ImageNet 数据集上进行了大规模的训练。 为了使用训练模型进行特征提取,我们可以简单地加载模型提取输入数据的中间层输出。这些中间层的输出通常被认为是数据的有意义的特征表示。例如,对于图像分类任务,我们可以加载训练的 ResNet 模型,并通过前向传播得到图像在最后一层卷积层的输出(也称为特征图)。这些特征图可以被视为图像的高级特征表示,可以用于后续的任务,如图像检索或分类等。 通过使用训练模型进行特征提取,我们可以获得一些优势。首先,训练模型经过大规模数据集的训练,能够捕捉到通用的特征表示。这样,我们无需从零开始训练模型,可以在少量的数据上进行微调或直接使用。其次,特征提取能够减少计算量和内存消耗,因为我们只需运行输入数据的前向传播,并截取中间层的输出,而无需通过后向传播进行反向更新。 总之,PyTorch 提供了方便的接口和训练模型,使得特征提取变得简单且高效。通过使用训练模型,我们可以获得数据的有意义的特征表示,并在后续的任务中得到更好的性能。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值