# Reference: https://github.com/ducha-aiki/affnet
import torch
import torch.nn as nn
import numpy as np
import sys
import os
import time
from PIL import Image
from torch.autograd import Variable
import torch.nn.functional as F
from examples.hesaffnet.SparseImgRepresenter import ScaleSpaceAffinePatchExtractor
from examples.hesaffnet.LAF import denormalizeLAFs, LAFs2ell, abc2A
from examples.hesaffnet.Utils import line_prepender
from examples.hesaffnet.architectures import AffNetFast
from examples.hesaffnet.HardNet import HardNet
import matplotlib.pyplot as plt
import cv2
def showimg(dst, name):
    """Write an RGB image to ./save_img/<name>.jpg.

    Args:
        dst: image array in RGB channel order; it is converted to BGR
            (cv2.COLOR_RGB2BGR) because OpenCV writes BGR data.
        name: file stem; passed through str().

    Raises:
        IOError: if cv2.imwrite reports that the file could not be written.
    """
    # cv2.imwrite signals failure (e.g. a missing output directory) through its
    # boolean return value, so make sure the directory exists and check it.
    os.makedirs("./save_img", exist_ok=True)
    out_path = "./save_img/" + str(name) + ".jpg"
    bgr = cv2.cvtColor(dst, cv2.COLOR_RGB2BGR)
    if not cv2.imwrite(out_path, bgr):
        raise IOError("failed to write image: " + out_path)
# Set to True to move the detector and descriptor to the GPU.
USE_CUDA = False

### Initialization
# AffNet with PS=32, matching the patch size used by extract_patches_from_pyr.
AffNetPix = AffNetFast(PS=32)
weightd_fname = './pretrained/AffNet.pth'
# map_location='cpu' lets checkpoints that were saved with CUDA tensors load on
# CPU-only machines. load_state_dict copies the values into the module's own
# (CPU) parameters, and the .cuda() calls below move the modules when
# USE_CUDA is set, so the final device placement is unchanged.
checkpoint = torch.load(weightd_fname, map_location='cpu')
AffNetPix.load_state_dict(checkpoint['state_dict'])
AffNetPix.eval()
detector = ScaleSpaceAffinePatchExtractor(mrSize=5.192, num_features=3000,
                                          border=5, num_Baum_iters=1,
                                          AffNet=AffNetPix)
descriptor = HardNet()
model_weights = './HardNet++.pth'
hncheckpoint = torch.load(model_weights, map_location='cpu')
descriptor.load_state_dict(hncheckpoint['state_dict'])
descriptor.eval()
if USE_CUDA:
    detector = detector.cuda()
    descriptor = descriptor.cuda()
# Image loading (these paths were originally read from sys.argv[1..3]).
input_img_fname1, input_img_fname2 = (
    'examples/hesaffnet/img/1_IMG_Texture_8Bit.png',
    'examples/hesaffnet/img/7_IMG_Texture_8Bit.png',
)
output_img_fname = 'examples/hesaffnet/kpi_match.png'
# Crop boxes [x0, y0, x1, y1] applied to image 1 and image 2 respectively:
#   [793, 579, 1508, 1070]
#   [783, 550, 1483, 1062]
# Cropped RGB images, appended in load order by load_grayscale_var().
ori_image = []
def load_grayscale_var(fname, x):
    """Load and crop an image; return its grayscale version as a 4-D tensor.

    Args:
        fname: path to the image file (opened with PIL and converted to RGB).
        x: crop box [x0, y0, x1, y1]; rows x[1]:x[3] and columns x[0]:x[2]
            are kept.

    Returns:
        float32 tensor of shape (1, 1, H, W), where H and W are the crop's
        height and width. Each value is the mean of the R, G, B channels.
        The tensor is moved to the GPU when USE_CUDA is set.

    Side effects:
        Appends the cropped RGB array to the module-level ``ori_image`` list.
    """
    # Image.open raises on failure instead of returning None. The context
    # manager closes the underlying file once the pixels have been decoded.
    with Image.open(fname) as pil_img:
        rgb = np.array(pil_img.convert('RGB'))
    crop = rgb[x[1]:x[3], x[0]:x[2], :]
    ori_image.append(crop)
    # Grayscale image fed to the feature detector: mean over the colour channels.
    gray = np.mean(crop, axis=2).astype(np.float32)
    # Variable(..., volatile=True) has no effect in PyTorch >= 0.4 (gradient
    # tracking is disabled by torch.no_grad() at inference time instead), so a
    # plain tensor is equivalent.
    var_image_reshape = torch.from_numpy(gray).view(1, 1, gray.shape[0], gray.shape[1])
    if USE_CUDA:
        var_image_reshape = var_image_reshape.cuda()
    return var_image_reshape
# Crop each input to its [x0, y0, x1, y1] region and turn it into a
# (1, 1, H, W) grayscale tensor; image 1 is loaded before image 2 so the RGB
# crops land in ori_image in that order.
img1, img2 = (
    load_grayscale_var(fname, box)
    for fname, box in (
        (input_img_fname1, [793, 579, 1508, 1070]),
        (input_img_fname2, [783, 550, 1483, 1062]),
    )
)
## Detection and description
def get_geometry_and_descriptors(img, det, desc):
    """Detect local affine frames (LAFs) in an image and describe their patches.

    Args:
        img: input image tensor; presumably (1, 1, H, W) as produced by
            load_grayscale_var — TODO confirm for other callers.
        det: detector, called as ``det(img, do_ori=True)`` and returning
            ``(LAFs, responses)``. It must also provide
            ``extract_patches_from_pyr(LAFs, PS=...)``.
        desc: descriptor model that maps the extracted patches to descriptors.

    Returns:
        ``(LAFs, descriptors)``, computed without gradient tracking.
    """
    with torch.no_grad():
        LAFs, _resp = det(img, do_ori=True)
        # Use the models passed in, not the module-level detector/descriptor,
        # so the parameters actually take effect.
        patches = det.extract_patches_from_pyr(LAFs, PS=32)
        descriptors = desc(patches)
    return LAFs, descriptors
# Run detection + description on image 1, then on image 2.
(LAFs1, descriptors1), (LAFs2, descriptors2) = [
    get_geometry_and_descriptors(im, detector, descriptor) for im in (img1, img2)
]
# Print image 1's LAF and descriptor shapes as a quick sanity check.
for _shown in (LAFs1, descriptors1):
    print(_shown.shape)
# Brute-force matching with the second-nearest-neighbour (SNN) ratio test.
from examples.hesaffnet.Losses import distance_matrix_vector
SNN_threshold = 0.8
dist_matrix = distance_matrix_vector(descriptors1, descriptors2)
# Nearest neighbour in image 2 for every descriptor of image 1.
min_dist, idxs_in_2 = torch.min(dist_matrix, 1)
# Mask out only each row's own nearest neighbour before looking for the second
# nearest. Indexing with [:, idxs_in_2] would blank every column that is *any*
# row's nearest neighbour. That inflates the second-nearest distances and lets
# ambiguous matches pass the ratio test.
row_idx = torch.arange(dist_matrix.size(0), device=dist_matrix.device)
dist_matrix[row_idx, idxs_in_2] = 100000
min_2nd_dist, idxs_2nd_in_2 = torch.min(dist_matrix, 1)
mask = (min_dist / (min_2nd_dist + 1e-8)) <= SNN_threshold
# Build the index range on the mask's device so boolean indexing also works
# when USE_CUDA is enabled.
tent_matches_in_1 = torch.arange(idxs_in_2.size(0), device=mask.device)[mask]
tent_matches_in_2 = idxs_in_2[mask]
tent_matches_in_1 = tent_matches_in_1.cpu().long()
tent_matches_in_2 = tent_matches_in_2.cpu().long()
from examples.hesaffnet.LAF import visualize_LAFs, convertLAFs_to_A23format, LAF2pts
import seaborn as sns


def _paste_canvas(height, width, placements):
    """Return a black (height, width, 3) uint8 canvas with images pasted in.

    placements: iterable of (row, col, img); img's top-left corner goes to (row, col).
    """
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    for row, col, img in placements:
        canvas[row:row + img.shape[0], col:col + img.shape[1], :] = img
    return canvas


# (H, W) grayscale arrays; only their shapes are used below.
image1 = img1.cpu().numpy().squeeze()
image2 = img2.cpu().numpy().squeeze()
# Matched LAFs, K x 2 x 3. No .squeeze() here: with exactly one match it would
# collapse to (2, 3) and break the 3-D indexing (work_LAFs[:, 0, 2],
# work_LAFs[i, :, :]) used when drawing the matches.
LAF1 = LAFs1[tent_matches_in_1, :, :].cpu().numpy()
LAF2 = LAFs2[tent_matches_in_2, :, :].cpu().numpy()
work_LAFs1 = convertLAFs_to_A23format(LAF1)
work_LAFs2 = convertLAFs_to_A23format(LAF2)

h1, w1 = image1.shape[0], image1.shape[1]
h2, w2 = image2.shape[0], image2.shape[1]

# dst1: image 1 top-left, image 2 bottom-right.
dst1 = _paste_canvas(h1 + h2, w1 + w2, [(0, 0, ori_image[0]), (h1, w1, ori_image[1])])
showimg(dst1, 'dst1')
# dst2: image 1 bottom-left, image 2 top-right.
dst2 = _paste_canvas(h1 + h2, w1 + w2, [(h2, 0, ori_image[0]), (0, w1, ori_image[1])])
showimg(dst2, 'dst2')
# dst3: images stacked vertically (rows for image 2 span its height, h2).
max_width = max(w1, w2)
dst3 = _paste_canvas(h1 + h2, max_width, [(0, 0, ori_image[0]), (h1, 0, ori_image[1])])
showimg(dst3, 'dst3')
# dst4: images side by side.
max_hight = max(h1, h2)
dst4 = _paste_canvas(max_hight, w1 + w2, [(0, 0, ori_image[0]), (0, w1, ori_image[1])])
showimg(dst4, 'dst4')
# Keypoint (x, y) positions are column 2 of each 2x3 LAF (the same values used
# as line endpoints below). Stack them along axis=1 so each x stays paired with
# its own y. Reshaping a (2, K) [xs, ys] array row-major would instead pair
# unrelated coordinates (x0, x1), (x2, x3), ...
src_pts = np.stack([work_LAFs1[:, 0, 2], work_LAFs1[:, 1, 2]], axis=1).astype(np.float32).reshape((-1, 1, 2))
dst_pts = np.stack([work_LAFs2[:, 0, 2], work_LAFs2[:, 1, 2]], axis=1).astype(np.float32).reshape((-1, 1, 2))
# RANSAC reprojection threshold, in pixels.
ransac_reproj_thresh = 60
homo, mask = None, None
# findHomography needs at least 4 correspondences.
if len(src_pts) >= 4:
    homo, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, ransac_reproj_thresh)
if mask is None:
    # Too few matches, or no homography found: draw no lines instead of crashing.
    print('Homography estimation skipped or failed; no inlier matches drawn.')
    mask = np.zeros((len(src_pts), 1), dtype=np.uint8)

thickness = 2
lineType = 8
line_color = (255, 255, 0)
h1, w1 = image1.shape[0], image1.shape[1]
h2 = image2.shape[0]
# Draw one line per RANSAC inlier on each canvas, offsetting the endpoints by
# where each image was pasted on that canvas.
for i in np.flatnonzero(mask.ravel()):
    x1, y1 = work_LAFs1[i, 0, 2], work_LAFs1[i, 1, 2]
    x2, y2 = work_LAFs2[i, 0, 2], work_LAFs2[i, 1, 2]
    # dst1: image 2 offset by (w1, h1).
    cv2.line(dst1, (int(x1), int(y1)), (int(x2 + w1), int(y2 + h1)), line_color, thickness, lineType)
    # dst2: image 1 offset down by h2; image 2 offset right by w1.
    cv2.line(dst2, (int(x1), int(y1 + h2)), (int(x2 + w1), int(y2)), line_color, thickness, lineType)
    # dst3: image 2 offset down by h1.
    cv2.line(dst3, (int(x1), int(y1)), (int(x2), int(y2 + h1)), line_color, thickness, lineType)
    # dst4: image 2 offset right by w1.
    cv2.line(dst4, (int(x1), int(y1)), (int(x2 + w1), int(y2)), line_color, thickness, lineType)
showimg(dst1, 'dst11')
showimg(dst2, 'dst22')
showimg(dst3, 'dst33')
showimg(dst4, 'dst44')