给定一张查询图片“22.png”，在图片库（“database”目录）中检索出与其相似度最高的 3 张图片。
1. 使用深度神经网络提取图片特征
1.1 vgg16提取图片特征
# -*- coding: UTF-8 -*-
import os

import h5py
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import preprocess_input as preprocess_input_vgg
from keras.preprocessing import image
class VGGNet:
    """Wrap a VGG16 backbone that turns an image file into an L2-normalised feature vector."""

    def __init__(self):
        self.input_shape = (224, 224, 3)
        # 'imagenet' loads pretrained weights; None would mean random
        # initialisation (no pretrained weights loaded).
        self.weight = 'imagenet'
        # Global pooling applied to the last convolutional block: 'max' or 'avg'.
        self.pooling = 'max'
        # include_top=False drops the fully-connected classifier head so the
        # model outputs pooled convolutional features instead of class scores.
        self.model_vgg = VGG16(weights=self.weight,
                               input_shape=self.input_shape,
                               pooling=self.pooling,
                               include_top=False)

    def vgg_extract_feat(self, img_path):
        """Extract the L2-normalised VGG16 last-conv-block feature of one image.

        Args:
            img_path: path of the image file to read; it is resized to
                self.input_shape[:2].

        Returns:
            numpy array: the pooled features of the image divided by their L2
            norm. An all-zero feature is returned unchanged rather than
            turned into NaN.
        """
        img = image.load_img(img_path, target_size=self.input_shape[:2])
        # The model expects a batch, so add a leading batch axis of size 1.
        batch = np.expand_dims(image.img_to_array(img), axis=0)
        batch = preprocess_input_vgg(batch)
        feat = self.model_vgg.predict(batch)[0]
        norm = np.linalg.norm(feat)
        # Guard: 0/0 would produce NaN, and NaN scores end up ranked first
        # after np.argsort(...)[::-1] during retrieval.
        return feat / norm if norm > 0 else feat
1.2 resnet50提取图片特征
from keras.applications.resnet50 import ResNet50
from keras.applications.resnet50 import preprocess_input as preprocess_input_resnet
# Snippet for VGGNet.__init__ (next to self.model_vgg): builds a ResNet50
# backbone from the same input shape, weights and pooling settings set there.
# include_top=False drops the classifier head so pooled conv features come out.
self.model_resnet = ResNet50(
    include_top=False,
    weights=self.weight,
    input_shape=tuple(self.input_shape[:3]),
    pooling=self.pooling,
)
# Extract the ResNet50 last-conv-block feature (a VGGNet method; needs self.model_resnet).
def resnet_extract_feat(self, img_path):
    """Extract the L2-normalised ResNet50 feature vector of one image.

    Args:
        img_path: path of the image file to read; it is resized to
            self.input_shape[:2].

    Returns:
        numpy array: the pooled ResNet50 features divided by their L2 norm.
        An all-zero feature is returned unchanged rather than turned into NaN.
    """
    img = image.load_img(img_path, target_size=self.input_shape[:2])
    # Add the leading batch axis the model expects.
    batch = np.expand_dims(image.img_to_array(img), axis=0)
    batch = preprocess_input_resnet(batch)
    feat = self.model_resnet.predict(batch)[0]
    norm = np.linalg.norm(feat)
    # Guard against 0/0 -> NaN, which would sort to the top of the ranking.
    return feat / norm if norm > 0 else feat
1.3 densenet121提取图片特征
from keras.applications.densenet import DenseNet121
from keras.applications.densenet import preprocess_input as preprocess_input_densenet
# Snippet for VGGNet.__init__ (next to self.model_vgg): builds a DenseNet121
# backbone from the same input shape, weights and pooling settings set there.
# include_top=False drops the classifier head so pooled conv features come out.
self.model_densenet = DenseNet121(
    include_top=False,
    weights=self.weight,
    input_shape=tuple(self.input_shape[:3]),
    pooling=self.pooling,
)
# Extract the DenseNet121 last-conv-block feature (a VGGNet method; needs self.model_densenet).
def densenet_extract_feat(self, img_path):
    """Extract the L2-normalised DenseNet121 feature vector of one image.

    Args:
        img_path: path of the image file to read; it is resized to
            self.input_shape[:2].

    Returns:
        numpy array: the pooled DenseNet121 features divided by their L2 norm.
        An all-zero feature is returned unchanged rather than turned into NaN.
    """
    img = image.load_img(img_path, target_size=self.input_shape[:2])
    # Add the leading batch axis the model expects.
    batch = np.expand_dims(image.img_to_array(img), axis=0)
    batch = preprocess_input_densenet(batch)
    feat = self.model_densenet.predict(batch)[0]
    norm = np.linalg.norm(feat)
    # Guard against 0/0 -> NaN, which would sort to the top of the ranking.
    return feat / norm if norm > 0 else feat
2. 从图像库抽取特征
使用深度神经网络提取“database”目录中每张图片的文件名（name）和特征向量（feature）。
def save_features(database='database', index='models/vgg_featureCNN.h5',
                  extensions=('.jpg',)):
    """Extract features for every image in *database* and store them in HDF5.

    Args:
        database: directory scanned (non-recursively) for images.
        index: output HDF5 file; its parent directory is created if missing.
        extensions: accepted filename suffix(es), matched case-insensitively.

    The file receives two datasets: 'dataset_1' holds the feature matrix (one
    row per image) and 'dataset_2' the matching file names as UTF-8 bytes.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)
    # Sort so the row order of the index is reproducible; os.listdir order
    # is arbitrary.
    img_list = [os.path.join(database, f) for f in sorted(os.listdir(database))
                if f.lower().endswith(suffixes)]
    feats, names = extract_features_and_images_index(img_list)

    out_dir = os.path.dirname(index)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # np.string_ no longer exists in NumPy 2.0; encoding explicitly as UTF-8
    # also lets non-ASCII names round-trip (plot_img decodes with 'utf-8').
    encoded_names = np.array([n.encode('utf-8') for n in names], dtype=np.bytes_)
    # The context manager closes the file even if writing fails.
    with h5py.File(index, 'w') as h5f:
        h5f.create_dataset('dataset_1', data=np.array(feats))
        h5f.create_dataset('dataset_2', data=encoded_names)
上面的 save_features 将提取结果写入 HDF5 文件；下面的函数负责逐张提取图片特征，并返回特征列表和文件名列表：
def extract_features_and_images_index(img_path_list, model=None):
    """Extract a normalised feature vector and the file name for each image.

    Args:
        img_path_list: image file paths to process, in order.
        model: object exposing vgg_extract_feat(path). A new VGGNet is built
            when omitted; pass an existing one to avoid reloading the weights.

    Returns:
        (feats, names): two position-aligned lists holding the feature
        vectors and the bare file names (directory part stripped).
    """
    if model is None:
        model = VGGNet()
    total = len(img_path_list)
    feats = []
    names = []
    for i, img_path in enumerate(img_path_list, start=1):
        # Change this call to switch the feature-extraction network
        # (e.g. a resnet/densenet extractor method).
        feats.append(model.vgg_extract_feat(img_path))
        names.append(os.path.basename(img_path))
        print("extracting feature from image No. %d , %d images in total" % (i, total))
    return feats, names
3. 加载特征检索相似图片
检索出与查询图片相似度最高的三张图片（数量由 maxres 参数控制）。
def get_similarity_top3_picture(image_path='22.png', maxres=3,
                                index_path='models/vgg_featureCNN.h5', show=True):
    """Rank indexed database images by similarity to *image_path*.

    Args:
        image_path: query image file.
        maxres: number of top-ranked results to print (and display).
        index_path: HDF5 feature index written by save_features.
        show: when True, display each of the top results with plot_img.

    Returns:
        [top_1_md5, top_1_score]: the best match's file name without its
        extension (the name suggests md5-named files) and its similarity
        score, e.g. ['bf43ddd28d6a2544b4ba8f95002674ed', 0.5255763].
    """
    feats, names = get_feature_from_hdf5(index_path)
    model = VGGNet()
    # Must be the same network that built the index.
    img_feat = model.vgg_extract_feat(image_path)
    # The query and stored features are both L2-normalised, so the dot
    # product already equals cosine similarity.
    scores = np.dot(img_feat, feats.T)
    rank_ids = np.argsort(scores)[::-1]
    rank_scores = scores[rank_ids]
    print(rank_ids)
    print(rank_scores)

    for i, row in enumerate(rank_ids[:maxres]):
        # Report the i-th ranked image (names[row]), not the i-th stored one.
        print("image names: " + str(names[row]) + " scores: %f" % rank_scores[i])
        if show:
            # plot_img needs the file name, not the integer row number.
            plot_img(i, names[row])

    top_1_score = rank_scores[0]
    top_1_name = names[rank_ids[0]]
    if isinstance(top_1_name, bytes):
        top_1_name = top_1_name.decode('utf-8')
    top_1_md5 = os.path.splitext(str(top_1_name))[0].strip()
    return [top_1_md5, top_1_score]
从hdf5读取特征
def get_feature_from_hdf5(path='models/vgg_featureCNN.h5'):
    """Load the indexed feature vectors and their image names.

    Args:
        path: HDF5 file written by save_features.

    Returns:
        (feats, names): arrays read from dataset 'dataset_1' (features, one
        row per image) and 'dataset_2' (the matching file names).
    """
    # The context manager closes the file even if a dataset is missing.
    with h5py.File(path, 'r') as h5f:
        feats = h5f['dataset_1'][:]
        names = h5f['dataset_2'][:]
    return feats, names
显示图片
import matplotlib.image as mpimg
def plot_img(i, index, database='database'):
    """Display one retrieved image titled 'search output <i + 1>'.

    Args:
        i: zero-based rank of the result; the title shows i + 1.
        index: file name of the image inside *database*, either bytes (as
            read back from the HDF5 index) or str.
        database: directory holding the images.
    """
    name = index.decode('utf-8') if isinstance(index, bytes) else str(index)
    img = mpimg.imread(os.path.join(database, name))
    plt.title("search output %d" % (i + 1))
    plt.imshow(img)
    plt.show()