import torch
import torch.nn as nn
import os
from torchvision import models, transforms
from torch.autograd import Variable
import numpy as np
from PIL import Image
import torchvision.models as models
import pandas as pd
# Directory containing the input images.
file_path = './images/'
# Directory where the per-image feature tensors are written ('' = current directory).
save_path = ''

# Preprocessing for torchvision's ImageNet-pretrained models: resize the short
# side to 256, center-crop 224x224, convert to a [0, 1] float tensor, then
# normalize with the ImageNet channel mean/std. Without Normalize the
# pretrained weights receive inputs on a different scale than they were
# trained on, and the extracted features are skewed.
transform1 = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Deterministic processing order; non-file entries are skipped in the loop.
names = sorted(os.listdir(file_path))

# NOTE: `pretrained=True` is deprecated in torchvision >= 0.13 in favour of
# `weights=models.ResNet50_Weights.IMAGENET1K_V1`; kept for compatibility with
# older torchvision releases.
resnet50 = models.resnet50(pretrained=True)
# Drop the last child module (the final fully-connected layer) so the network
# outputs pooled features instead of class logits.
feature_extractor = torch.nn.Sequential(*list(resnet50.children())[:-1])
# Inference mode: BatchNorm must use its stored running statistics rather than
# the statistics of a one-image batch, and must not update them.
feature_extractor.eval()

for name in names:
    pic = os.path.join(file_path, name)
    if not os.path.isfile(pic):
        continue  # e.g. sub-directories
    try:
        # `with` closes the file handle; convert('RGB') turns grayscale,
        # palette and RGBA images into the 3 channels the model expects.
        with Image.open(pic) as img:
            img1 = transform1(img.convert('RGB'))
    except OSError as err:
        # Pillow raises OSError (UnidentifiedImageError in newer releases) for
        # files it cannot decode; report and keep processing the others.
        print(f'skipping {pic}: {err}')
        continue
    x = img1.unsqueeze(0)  # add a batch dimension of size 1
    # Feature extraction needs no autograd graph.
    with torch.no_grad():
        # squeeze() drops the size-1 batch/spatial dims; for torchvision's
        # ResNet-50 this is expected to leave a 2048-element vector.
        y = feature_extractor(x).squeeze().cpu()
    # splitext handles extensions of any length (.jpg, .jpeg, .png, ...).
    stem = os.path.splitext(name)[0]
    torch.save(y, os.path.join(save_path, stem + '.pth'))
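# 如何用resnet50提取图片特征【咨询大厂大佬版】
# (How to extract image features with ResNet-50)
# 最新推荐文章于 2024-04-11 16:21:49 发布
# (Latest recommended article published 2024-04-11 16:21:49)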