import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
# Extract a pooled feature vector for one image using a pre-trained
# ResNet-101 whose final fully connected (classification) layer is removed.

# Path of the image to embed; change this to point at your own file.
IMAGE_PATH = "/home/a/Downloads/Laysan_Albat1_545.jpg"

# Load pre-trained ResNet-101 model.
# NOTE(review): recent torchvision releases prefer the `weights=` argument
# over `pretrained=True`; kept unchanged here so older installs still work —
# confirm against the installed torchvision version.
resnet = models.resnet101(pretrained=True)

# Remove the final fully connected layer: keep every child module except the
# last one, so the network's output is the pooled feature map, not class logits.
feature_extractor = torch.nn.Sequential(*list(resnet.children())[:-1])

# Set model to evaluation mode (inference behaviour for batch-norm/dropout).
feature_extractor.eval()

# Define image transforms: resize, center-crop to 224x224, convert to a
# tensor, then normalize each channel with the given mean/std.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load image. The context manager closes the underlying file as soon as the
# RGB copy has been produced, instead of leaving the handle open until GC.
with Image.open(IMAGE_PATH) as image_file:
    img = image_file.convert('RGB')

# Apply transforms to image.
img_tensor = transform(img)

# Add batch dimension to tensor (the model expects a batch as its first axis).
img_tensor = img_tensor.unsqueeze(0)

# Pass image through the truncated ResNet-101; gradients are not needed for
# feature extraction, so disable autograd tracking.
with torch.no_grad():
    features = feature_extractor(img_tensor)

# Flatten everything after the batch axis into a single vector per image.
feature_vector = torch.flatten(features, start_dim=1)

# Print shape of feature vector (expected to be (1, 2048) for ResNet-101).
print(feature_vector.shape)
# ResNet-101 feature extraction: 2048-dimensional feature vector
# (Source page footer: latest recommended article published 2024-05-21 22:44:12)