PyTorch:利用预训练好的ResNet-152网络提取图片特征

这篇博文相当于是对上一篇博文Pytorch:利用预训练好的VGG16网络提取图片特征 的补充,本文中提到的提取方式与前文中的不同。

另外,因为TorchVision提供的训练好了的ResNet效果不好,所以本文中将会使用由ruotianluo提供的从Caffe转换过来的ResNet模型(具体可以看这个repo,如果好奇怎么转换的话)。

代码

以下代码节选自pytorch-vqapreprocess-images.py,作者是Cyanogenoid

import h5py
from torch.autograd import Variable
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.models as models
from tqdm import tqdm

import config
import data
import utils
from resnet import resnet as caffe_resnet

class Net(nn.Module):
    def __init__
在深度学习的图像处理中,利用预训练模型提取特征是一项基础但重要的技能。本文推荐的《PyTorch利用Resnet提取特征并保存为txt教程》能够为你提供从环境搭建到模型应用的全流程指导。 参考资源链接:[PyTorch利用Resnet提取特征并保存为txt教程](https://wenku.csdn.net/doc/645cd61395996c03ac3f869e) 首先,确保你已经安装了PyTorch及相关库,并熟悉基本的PyTorch操作。以下是使用预训练ResNet模型提取特征并保存为txt文件的详细步骤: 1. **环境设置**:导入所需的库和模块。 ```python import os import torch import torch.nn as nn import torchvision.models as models import torchvision.transforms as transforms import numpy as np from PIL import Image ``` 2. **定义特征目录**:创建用于存储特征的目录。 ```python features_dir = 'features' os.makedirs(features_dir, exist_ok=True) ``` 3. **图像预处理**:定义图像预处理流程。 ```python preprocess = ***pose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) ``` 4. **加载预训练ResNet模型**:加载预训练ResNet-50模型,并替换最后的全连接层。 ```python resnet = models.resnet50(pretrained=True) num_ftrs = resnet.fc.in_features resnet.fc = nn.Linear(num_ftrs, num_ftrs) # 可以根据需要调整输出特征维度 resnet.eval() # 设置为评估模式,不计算梯度 ``` 5. **提取特征**:定义一个函数来提取图像特征。 ```python def extract_features(image_path): image = Image.open(image_path).convert('RGB') image = preprocess(image).unsqueeze(0) # 增加批次维度 with torch.no_grad(): features = resnet(image) return features ``` 6. **保存特征到txt文件**:将特征保存为txt文件。 ```python def save_features_to_txt(image_path, features): features = features.numpy().flatten().tolist() file_name = os.path.basename(image_path).split('.')[0] + '.txt' with open(os.path.join(features_dir, file_name), 'w') as f: f.write(' '.join(map(str, features))) ``` 7. **遍历图片提取保存特征**:遍历图片路径列表,提取并保存特征。 ```python image_paths = ['path/to/image1.jpg', 'path/to/image2.jpg'] # 替换为实际图片路径 for image_path in image_paths: features = extract_features(image_path) save_features_to_txt(image_path, features) ``` 以上步骤涵盖了从环境准备、模型加载、特征提取到数据保存的完整流程。通过实践这些步骤,你将能够掌握在PyTorch中应用预训练ResNet模型进行图像特征提取的关键技术。为了深入学习和掌握更多的图像处理技巧,建议阅读《PyTorch利用Resnet提取特征并保存为txt教程》来获取更详尽的知识和高级应用实例。 参考资源链接:[PyTorch利用Resnet提取特征并保存为txt教程](https://wenku.csdn.net/doc/645cd61395996c03ac3f869e)
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值