"""
@File : 基于resnet18的自然图像检索.py
@Time : 2021-04-18 14:25
@Author : XD
@Email : gudianpai@qq.com
@Software: PyCharm
"""
from tqdm import tqdm
import glob
from sklearn. preprocessing import normalize
import matplotlib. pyplot as plt
import numpy as np
import cv2
import torch
import torch. nn as nn
from torchvision import models
# ImageNet-pretrained ResNet-18 with its last child (the classifier head)
# dropped, so a forward pass yields a 512-dim feature per image (see the
# reshape to 512 columns in main) instead of class logits.
# NOTE(review): newer torchvision deprecates ``pretrained=True`` in favour of
# ``weights=...``; kept here for compatibility with older installs.
model = models.resnet18(pretrained=True)
model = nn.Sequential(*list(model.children())[:-1])
# Inference only: eval mode makes BatchNorm use its stored running statistics
# rather than per-batch statistics, so an image's feature does not depend on
# which other images happen to share its batch.
model.eval()
def show_img(ori_path, aim_path):
    """Display a query image above its retrieved neighbours.

    The figure is a 3x5 grid: the query image occupies the first cell of the
    top row, and up to ten retrieved images fill the two bottom rows
    (cells 6-15) from left to right.

    Args:
        ori_path: Path of the query image.
        aim_path: Sequence of paths of retrieved images, best match first.
            Only the first ten are shown; fewer than ten is allowed.

    Raises:
        FileNotFoundError: If OpenCV cannot read one of the images.
    """
    def _read_rgb(path):
        # cv2.imread signals failure by returning None rather than raising.
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"cannot read image: {path}")
        # OpenCV decodes to BGR; matplotlib expects RGB.
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    plt.figure(figsize=(20, 10))

    plt.subplot(3, 5, 1)
    plt.imshow(_read_rgb(ori_path))
    plt.xticks([])
    plt.yticks([])
    plt.title('Original Image')

    # Results start on the second row (cell 6); the grid has room for ten.
    # Slicing (instead of range(10)) tolerates shorter result lists.
    for idx, path in enumerate(aim_path[:10]):
        plt.subplot(3, 5, 6 + idx)
        plt.imshow(_read_rgb(path))
        plt.xticks([])
        plt.yticks([])
    plt.show()
def cnn_feat_method(img):
    """Extract a flattened feature vector for a single image with ``model``.

    Args:
        img: Image array of shape (H, W, 3); it is resized to 224x224 here,
            matching the preprocessing main() applies to the database.
            Presumably RGB (main() converts from BGR before feeding the
            model) — callers using cv2.imread should convert first.

    Returns:
        1-D numpy array: the model output for this image, flattened.
    """
    # Bug fix: the original resized into a temporary but then expanded the
    # *unresized* image, so the resize was silently discarded and features
    # were computed at the image's native resolution.
    resized = cv2.resize(img, (224, 224))
    # HWC -> NCHW float32 with a batch dimension of 1.
    # NOTE(review): pretrained torchvision models expect ImageNet
    # mean/std-normalised input; raw 0-255 values are fed here (as in main).
    batch = np.transpose(resized[np.newaxis], (0, 3, 1, 2)).astype(np.float32)
    tensor = torch.from_numpy(batch)
    # Idempotent guard: BatchNorm must use running statistics at inference.
    model.eval()
    with torch.no_grad():
        feat = model(tensor)
    return feat.flatten().numpy()
def _load_db_image(path):
    """Read *path* with OpenCV and return it as a 224x224 RGB uint8 array.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file (cv2.imread returns
            None instead of raising).
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"cannot read image: {path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV decodes to BGR
    return cv2.resize(img, (224, 224))


def _extract_db_features(paths, batch_size):
    """Run ``model`` over the images at *paths*; return one feature row per image.

    Images are decoded and converted to float32 one batch at a time, so only
    a single batch of float pixels is in memory rather than the whole database.
    """
    feats = []
    for start in tqdm(range(0, len(paths), batch_size)):
        pixels = np.stack([_load_db_image(p) for p in paths[start:start + batch_size]])
        # NHWC uint8 -> NCHW float32, the layout the conv layers expect.
        # NOTE(review): pretrained torchvision models expect ImageNet
        # mean/std-normalised input; raw 0-255 values are fed here.
        batch = torch.from_numpy(np.transpose(pixels, (0, 3, 1, 2)).astype(np.float32))
        with torch.no_grad():
            feats.append(model(batch).numpy())
    # Flatten each image's model output into a single row, e.g. (N, 512).
    return np.concatenate(feats, axis=0).reshape(len(paths), -1)


def main(search_idx=512, batch_size=50, pattern="./ImgDB/*.jpg"):
    """Index every image matching *pattern* and show the best matches for one.

    Every image is embedded with ``model``. The rows are L2-normalised, and
    the database is ranked by cosine similarity to image ``search_idx``. The
    query and its ten nearest neighbours (excluding itself) are then shown
    via show_img.

    Args:
        search_idx: Index, in sorted path order, of the query image.
            Negative values index from the end as usual.
        batch_size: Number of images per forward pass.
        pattern: Glob pattern selecting the database images.

    Raises:
        FileNotFoundError: If no file matches *pattern* or an image is unreadable.
        IndexError: If *search_idx* is out of range for the database.
    """
    # glob order is filesystem-dependent; sort so search_idx names the same
    # image on every run and platform.
    jpgs = sorted(glob.glob(pattern))
    if not jpgs:
        raise FileNotFoundError(f"no images match {pattern!r}")
    n = len(jpgs)
    if not -n <= search_idx < n:
        raise IndexError(f"search_idx {search_idx} out of range for {n} images")
    search_idx %= n  # normalise negative indices for the self-match filter

    # Make BatchNorm use running statistics so features do not depend on
    # batch composition (idempotent if the model is already in eval mode).
    model.eval()
    cnn_feat = normalize(_extract_db_features(jpgs, batch_size))

    # Rows are L2-normalised, so these dot products are cosine similarities.
    scores = np.dot(cnn_feat[search_idx], cnn_feat.T)
    ranking = np.argsort(scores)[::-1]
    # Drop the query explicitly rather than assuming it ranks first
    # (an exact duplicate image can tie with it).
    ranking = ranking[ranking != search_idx]
    show_img(jpgs[search_idx], [jpgs[x] for x in ranking[:10]])
if __name__ == '__main__':
    # Run the retrieval demo only when executed as a script, not on import.
    main()
效果还可以
D:\ANACONDA\envs\pytorch_gpu\python.exe G:/七月在线/自然图像搜索/基于resnet18的自然图像检索.py
100%|██████████| 1000/1000 [00:01<00:00, 631.88it/s]
0 0.0
50 1.0
100 2.0
150 3.0
200 4.0
250 5.0
300 6.0
350 7.0
400 8.0
450 9.0
500 10.0
550 11.0
600 12.0
650 13.0
700 14.0
750 15.0
800 16.0
850 17.0
900 18.0
950 19.0
Process finished with exit code 0