下载训练数据(train_data.zip)
# -O (--output-document) names the downloaded file; lowercase -o would write wget's log there instead.
wget -O train_data.zip https://digix-algo-challenge.obs.cn-east-2.myhuaweicloud.com/2020/cv/6rKDTsB6sX8A1O2DA2IAq7TgHPdSPxJF/train_data.zip
使用 SE-ResNet50 提取图片特征向量
from torch.autograd import Variable
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import pretrainedmodels
from PIL import Image
import ssl
# NOTE(review): SECURITY - this swaps the process-wide default HTTPS context for
# one that skips certificate verification, so every later HTTPS connection that
# relies on ssl's default context (e.g. urllib, which the pretrained-weight
# download presumably goes through) is open to man-in-the-middle tampering.
# It appears to be a workaround for certificate errors while downloading the
# model weights; prefer fixing the local CA bundle or downloading the weights
# manually, then remove this line.
ssl._create_default_https_context = ssl._create_unverified_context
# Side length in pixels that every input image is resized to (square) before
# being fed to the network.
TARGET_IMG_SIZE = 224
# Shared torchvision transform that turns a PIL image into a torch tensor.
img_to_tensor = transforms.ToTensor()
def get_seresnet50(device=None):
    """Build a frozen SE-ResNet50 backbone for image feature extraction.

    The network comes from ``pretrainedmodels.se_resnet50()`` with its
    classification head removed: only the stem (``layer0``), the four residual
    stages (``layer1``-``layer4``) and the final average pool are kept. The
    pool's output is still 4-D, ``[batch, 2048, 1, 1]`` for a 224x224 input,
    so callers flatten it to get one 2048-dimensional vector per image.

    Args:
        device: optional torch device (e.g. ``"cuda"`` or
            ``torch.device("cuda:0")``) to move the model to. ``None`` (the
            default) leaves the model where it was created, i.e. the original
            CPU-only behaviour.

    Returns:
        An ``nn.Sequential`` in eval mode whose parameters all have
        ``requires_grad=False``.
    """
    encoder = pretrainedmodels.se_resnet50()
    model = nn.Sequential(
        encoder.layer0,
        encoder.layer1,
        encoder.layer2,
        encoder.layer3,
        encoder.layer4,
        # Global average pooling -> [batch, 2048, 1, 1] (not yet flattened).
        encoder.avg_pool,
    )
    # Inference only: freeze the weights so no gradients are tracked for them.
    for param in model.parameters():
        param.requires_grad = False
    if device is not None:
        # nn.Module.to moves the parameters in place (replaces the old
        # commented-out model.cuda() toggle).
        model.to(device)
    # Switch layers such as BatchNorm to their inference behaviour.
    model.eval()
    return model
# Feature extraction
def extract_feature(model, imgpath, *, normalize=False):
    """Extract a flat feature vector for a single image.

    The image is read, converted to 3-channel RGB, resized to
    ``TARGET_IMG_SIZE`` x ``TARGET_IMG_SIZE`` and passed through ``model`` as
    a batch of one. When ``model`` is an ``nn.Module`` with parameters, the
    batch is moved to the device of those parameters, so a model placed on a
    GPU works without editing this function.

    Args:
        model: callable taking a ``[1, 3, H, W]`` float tensor, typically the
            ``nn.Module`` returned by ``get_seresnet50()``.
        imgpath: image file path (anything ``PIL.Image.open`` accepts).
        normalize: if True, standardize the tensor with the ImageNet channel
            mean/std before inference. ImageNet-pretrained backbones are
            usually trained on normalized input, but the default (False)
            keeps the original un-normalized input so previously extracted
            features remain comparable.

    Returns:
        The model output for the image, flattened into a list of Python
        floats (2048 values for ``get_seresnet50()``).
    """
    # The context manager closes the file handle immediately; convert() loads
    # the pixels, and forcing RGB lets grayscale, RGBA and palette images
    # through (the backbone's first convolution expects exactly 3 channels).
    with Image.open(imgpath) as src:
        img = src.convert("RGB")
    img = img.resize((TARGET_IMG_SIZE, TARGET_IMG_SIZE))
    tensor = img_to_tensor(img)  # image -> tensor
    if normalize:
        tensor = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )(tensor)
    batch = torch.unsqueeze(tensor, 0)  # add the batch dimension
    # Follow the model's device (CPU unless the model was moved, e.g. to CUDA).
    if isinstance(model, nn.Module):
        first_param = next(model.parameters(), None)
        if first_param is not None:
            batch = batch.to(first_param.device)
    # Pure feature extraction: no autograd graph is needed.
    with torch.no_grad():
        result = model(batch)
    return result[0].flatten().cpu().tolist()
if __name__ == '__main__':
    import sys

    # Demo: extract and print the feature vector of one image. The image path
    # may be given as the first command-line argument; it defaults to the
    # original hard-coded "1.jpeg".
    image_path = sys.argv[1] if len(sys.argv) > 1 else "1.jpeg"
    model = get_seresnet50()
    feature = extract_feature(model, image_path)
    print(len(feature))
    print(feature)
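# If downloading the pretrained SE-ResNet50 weights fails, download the weight file manually and place it in the directory the library downloads to.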
参考
https://www.codenong.com/cs107121229/