用 LoFTR 模型得到两张图的匹配点(matched keypoints,即匹配关键点):
# Turn each raw image into a tensor with two extra leading axes, move it to
# the GPU and divide by 255 (presumably 8-bit grayscale input -- confirm).
img0, img1 = (
    torch.from_numpy(raw)[None, None].cuda() / 255.
    for raw in (img0_raw, img1_raw)
)
batch = {'image0': img0, 'image1': img1}

# Run LoFTR without gradient tracking; the matcher writes its predictions
# back into `batch`.
with torch.no_grad():
    matcher(batch)

# Pull the matched points and their confidences back to the CPU as numpy arrays.
mkpts0, mkpts1, mconf = (
    batch[key].cpu().numpy() for key in ('mkpts0_f', 'mkpts1_f', 'mconf')
)

# `color` and `text` are defined outside this snippet.
fig = make_matching_figure(img0_raw, img1_raw, mkpts0, mkpts1, color, text)
会得到两张图片以及对应的匹配点(形状为 N×2 的 numpy 数组,两个数组按行一一对应):img0_raw, img1_raw, mkpts0, mkpts1
绘图:
def make_matching_figure(img0, img1, mkpts0, mkpts1, color, text=[], path=None):
    """Draw an image pair with its matches plus a homography-aligned view.

    The figure is a 2x2 grid. The top row shows ``img0`` and ``img1`` with
    the matched points and a line joining each pair. The bottom row shows
    ``img0`` warped by a RANSAC-estimated homography (left) next to ``img1``
    (right). When no homography can be estimated (fewer than 4 matches, or
    OpenCV returns no model) the bottom-left panel carries a note instead of
    raising.

    Args:
        img0, img1: images to display; drawn with a gray colormap.
        mkpts0, mkpts1: matched points, one row per match with x in column 0
            and y in column 1; row i of ``mkpts0`` pairs with row i of
            ``mkpts1``.
        color: per-match colors, indexed by match number.
        text: lines of text drawn in the top-left corner. The default list
            is never mutated here.
        path: if truthy, the figure is saved to this path and closed.

    Returns:
        The matplotlib figure when ``path`` is falsy, otherwise ``None``.
    """
    # draw image pair
    fig, axes = plt.subplots(2, 2, figsize=(10, 6), dpi=75)
    axes[0, 0].imshow(img0, cmap='gray')
    axes[0, 1].imshow(img1, cmap='gray')
    for ax in axes[0]:  # clear all frames of the top row
        ax.get_yaxis().set_ticks([])
        ax.get_xaxis().set_ticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    plt.tight_layout(pad=1)

    # draw matches: convert data coordinates of both axes into figure
    # coordinates so that one line can span the two subplots
    fig.canvas.draw()
    transFigure = fig.transFigure.inverted()
    fkpts0 = transFigure.transform(axes[0, 0].transData.transform(mkpts0))
    fkpts1 = transFigure.transform(axes[0, 1].transData.transform(mkpts1))
    fig.lines = [
        matplotlib.lines.Line2D(
            (p0[0], p1[0]), (p0[1], p1[1]),
            transform=fig.transFigure, c=color[i], linewidth=1)
        for i, (p0, p1) in enumerate(zip(fkpts0, fkpts1))
    ]
    axes[0, 0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
    axes[0, 1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)

    # put txts: dark text on a bright corner, white text otherwise
    txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
    fig.text(
        0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
        fontsize=15, va='top', ha='left', color=txt_color)
    plt.tight_layout(pad=1)

    # A homography needs at least 4 point pairs. With fewer, OpenCV either
    # fails an assertion or returns no model, so skip the estimation.
    H = None
    if len(mkpts0) >= 4:
        H, _ = cv2.findHomography(mkpts0, mkpts1, cv2.RANSAC, 2.0)

    if H is not None:
        # H maps img0 into img1's frame, so the output canvas uses img1's
        # size (identical to the old behavior when both images match in size).
        align_image = cv2.warpPerspective(
            img0, H, (img1.shape[1], img1.shape[0]))
        axes[1, 0].imshow(align_image, cmap='gray')
    else:
        axes[1, 0].set_title('alignment unavailable: no homography')
    axes[1, 1].imshow(img1, cmap='gray')

    # Save before showing: after an interactive window opened by plt.show()
    # has been closed, a later savefig may write an empty image.
    if path:
        fig.savefig(str(path), bbox_inches='tight', pad_inches=0)
    plt.show()

    # close or return figure
    if path:
        plt.close(fig)
        return None
    return fig
报错
Traceback (most recent call last):
File "test_single_pair.py", line 123, in <module>
fig = make_matching_figure(img0_raw, img1_raw, mkpts0, mkpts1, color, text)
File "test_single_pair.py", line 46, in make_matching_figure
H, _ = cv2.findHomography(mkpts0,mkpts1, cv2.RANSAC, 2.0)
cv2.error: OpenCV(4.4.0) /tmp/pip-req-build-99ib2vsi/opencv/modules/calib3d/src/ptsetreg.cpp:174: error: (-215:Assertion failed) count >= 0 && count2 == count in function 'run'
原因:传给 cv2.findHomography(mkpts0, mkpts1, cv2.RANSAC, 2.0) 的匹配点少于 4 对,而估计单应矩阵(homography)至少需要 4 对对应点。
加入判断:
# A homography needs at least 4 point pairs. Everything that uses H stays
# inside the guard so H is never read when it was not (or could not be)
# estimated.
if len(mkpts0) >= 4:
    H, _ = cv2.findHomography(mkpts0, mkpts1, cv2.RANSAC, 2.0)
    if H is not None:
        # H maps img0 into img1's frame, hence img1's size for the output.
        align_image = cv2.warpPerspective(
            img0, H, (img1.shape[1], img1.shape[0]))
        axes[1, 0].imshow(align_image, cmap='gray')