利用影像GPS/POS信息剔除影像匹配粗差
之前写过一篇:读取colmap中database.db里的匹配结果,并且写入images.txt和point3D.txt
匹配由colmap中的sift算法完成。但实际上,如果已知影像的位置和姿态信息,就可以对影像匹配结果进行约束:设定地面高度,将左影像上一点(x1,y1)投影至地面,再将该地面点投影至右影像,得到投影点(x2_,y2_);若投影点与右影像上的匹配点(x2,y2)的距离超出某一阈值,则认为该匹配为误匹配。
对上次的代码进行修改,同样是从database读取匹配,并且写入points3D.txt和images.txt,主要新增了从影像中读取GPS/POS信息,以及基于GPS/POS滤除误匹配的功能:
import os
import argparse
import sqlite3
import cv2
import math
import numpy as np
import pyexiv2
from pyproj import Proj,transform
from PIL import Image
from PIL.ExifTags import TAGS, GPSTAGS
class FeaturePoint:
    """A 2D keypoint of one image together with the 3D point it is assigned to.

    ``point3D_id`` stays at -1 for as long as the keypoint has not been
    attached to any 3D point.
    """

    def __init__(self, x, y, point3D_id=-1):
        self.x, self.y = x, y
        self.point3D_id = point3D_id
class Observation:
    """One sighting of a 3D point: the image id and the keypoint index in it."""

    def __init__(self, image_id, point2D_idx):
        self.image_id, self.point2D_idx = image_id, point2D_idx
class Point3D:
    """A 3D point with coordinates and the list of observations that see it."""

    def __init__(self, x=0, y=0, z=0):
        self.x, self.y, self.z = x, y, z
        # Every instance owns its own list; it is filled by add_observation().
        self.observations = []

    def add_observation(self, observation):
        """Append ``observation`` to this point's track."""
        self.observations.append(observation)
def pair_id_to_image_ids(pair_id):
    """Split a pair id into the two image ids it encodes.

    This is the inverse of ``pair_id = image_id1 * 2147483647 + image_id2``.
    Integer division is used so that both ids come back as ``int``; true
    division would turn ``image_id1`` into a float.

    :param pair_id: combined pair identifier read from the database
    :return: tuple ``(image_id1, image_id2)`` of integers
    """
    image_id1, image_id2 = divmod(pair_id, 2147483647)
    return image_id1, image_id2
def get_keypoints(cursor, image_id):
    """Read the keypoint coordinates of one image from the ``keypoints`` table.

    The blob is decoded as float32 values and reshaped to ``(rows, cols)``;
    only the first two columns (taken to be x and y) are returned.

    :param cursor: open sqlite3 cursor on the database
    :param image_id: id of the image whose keypoints are wanted
    :return: array of shape ``(n, 2)``; an empty ``(0, 2)`` float32 array when
        the image has no keypoint row or the stored blob is NULL
    """
    # Name the columns explicitly instead of relying on the table's column order.
    cursor.execute(
        "SELECT rows, cols, data FROM keypoints WHERE image_id = ?;", (image_id,)
    )
    row = cursor.fetchone()
    if row is None or row[2] is None:
        return np.empty((0, 2), dtype=np.float32)
    n_rows, n_columns, raw_data = row
    keypoints = np.frombuffer(raw_data, dtype=np.float32).reshape(n_rows, n_columns)
    return keypoints[:, :2]
class ImageMetadata:
    """Size, focal length, GPS position and gimbal attitude of a single image.

    Besides storing the raw values, the constructor derives:

    * ``Xs`` / ``Ys`` - projected coordinates of the geographic position,
    * ``Zs`` - a copy of ``altitude``,
    * ``rotation_matrix`` - the result of ``gimbal2nPOK(roll, pitch, yaw)``.
    """

    def __init__(self, name, width, height, focal_length_pixels,
                 latitude, longitude, altitude, pitch, roll, yaw,
                 datum='wgs84', zone_width=3, projection='gauss'):
        # Identification and interior orientation.
        self.name = name
        self.width, self.height = width, height
        self.f = focal_length_pixels
        # Geographic position; it must be set before the projection is computed.
        self.latitude, self.longitude, self.altitude = latitude, longitude, altitude
        # Gimbal attitude angles.
        self.pitch, self.roll, self.yaw = pitch, roll, yaw
        # Exterior orientation derived from the values above.
        projected = self.calculate_projection_coordinates(datum, zone_width, projection)
        self.Xs, self.Ys = projected
        self.Zs = altitude
        self.rotation_matrix = gimbal2nPOK(roll, pitch, yaw)

    def calculate_projection_coordinates(self, datum, zone_width, projection):
        """Project this image's latitude/longitude and return (easting, northing)."""
        return latlon_to_custom_projection(
            self.latitude, self.longitude, datum, zone_width, projection
        )
def get_decimal_from_dms(dms, ref):
    """Convert a (degrees, minutes, seconds) triple to decimal degrees.

    The result is negated when ``ref`` is 'S' or 'W'.
    """
    deg, mins, secs = dms
    value = deg + mins / 60 + secs / 3600
    return value * -1 if ref in ('S', 'W') else value
# """ def parse_gps_data(metadata):
# """解析GPS数据"""
# lat_dms = [float(x) for x in metadata.get('Exif.GPSInfo.GPSLatitude', '0 0 0').split()]
# lat_ref = metadata.get('Exif.GPSInfo.GPSLatitudeRef', 'N')
# lon_dms = [float(x) for x in metadata.get('Exif.GPSInfo.GPSLongitude', '0 0 0').split()]
# lon_ref = metadata.get('Exif.GPSInfo.GPSLongitudeRef', 'E')
# latitude = get_decimal_from_dms(lat_dms, lat_ref)
# longitude = get_decimal_from_dms(lon_dms, lon_ref)
# altitude = float(metadata.get('Exif.GPSInfo.GPSAltitude', '0'))
# return latitude, longitude, altitude """
def parse_exif_fraction(exif_fraction_str):
    """Parse an EXIF rational such as ``"88/10"`` and return it as a number.

    A value without a slash (``"24"``, ``"8.8"`` or an actual number) is
    accepted as well and returned as a float, instead of failing on the
    numerator/denominator unpacking.

    :param exif_fraction_str: rational in ``numerator/denominator`` form, or a
        plain number
    :return: the quotient; 0 when the denominator is 0
    :raises ValueError: if the text is not a number or a two-part fraction
    """
    text = str(exif_fraction_str).strip()
    if '/' not in text:
        return float(text)
    numerator, denominator = map(int, text.split('/'))
    return numerator / denominator if denominator != 0 else 0
def latlon_to_custom_projection(latitude, longitude, datum='wgs84', zone_width=3, projection='gauss'):
    """
    Project geographic coordinates into a UTM or Gauss-Kruger zone.

    The zone is derived from ``longitude``, so two positions on different
    sides of a zone boundary end up in different projections.

    :param latitude: latitude in degrees
    :param longitude: longitude in degrees
    :param datum: 'wgs84' or 'cgcs2000'; it only selects the ellipsoid
        (WGS84 or GRS80)
    :param zone_width: Gauss-Kruger zone width in degrees, 3 or 6; ignored for
        'utm', whose zones are always 6 degrees wide
    :param projection: 'utm' or 'gauss'
    :return: tuple (easting, northing) in metres
    :raises ValueError: for an unsupported datum, projection or zone width
    """
    # Validate the arguments before touching the projection library.
    ellipsoids = {'wgs84': 'WGS84', 'cgcs2000': 'GRS80'}
    if datum not in ellipsoids:
        raise ValueError("Unsupported datum. Choose 'wgs84' or 'cgcs2000'.")
    ellps = ellipsoids[datum]

    if projection == 'utm':
        # UTM zones are 6 degrees wide and numbered 1..60 starting at 180 W.
        zone_number = min(int(math.floor((longitude + 180) / 6)) + 1, 60)
        # Only the ellipsoid is given: "+datum=CGCS2000" is not a PROJ datum name.
        proj_str = f'+proj=utm +zone={zone_number} +ellps={ellps} +units=m +no_defs'
    elif projection == 'gauss':
        if zone_width == 3:
            # 3-degree zones are centred on multiples of 3 degrees.
            zone_number = int(math.floor((longitude + 1.5) / 3))
            central_meridian = zone_number * 3
        elif zone_width == 6:
            # 6-degree zone n spans [6n - 6, 6n) and is centred on 6n - 3.
            zone_number = int(math.floor(longitude / 6)) + 1
            central_meridian = zone_number * 6 - 3
        else:
            raise ValueError("Unsupported zone width. Choose 3 or 6.")
        proj_str = f'+proj=tmerc +lat_0=0 +lon_0={central_meridian} +k=1 +x_0=500000 +y_0=0 +ellps={ellps} +units=m +no_defs'
    else:
        raise ValueError("Unsupported projection. Choose 'utm' or 'gauss'.")

    # Calling a Proj instance with (lon, lat) runs the forward projection; this
    # avoids the deprecated Proj(init=...) / transform() pair.
    easting, northing = Proj(proj_str)(longitude, latitude)
    return easting, northing
def get_focal_length_in_pixels(focal_length_str, focal_length_35mm_str, image_width, image_height):
    """Convert the EXIF focal length to pixels using the 35 mm equivalent.

    The sensor width is estimated from the crop factor (35 mm-equivalent focal
    length divided by the real one) applied to a 36 mm wide full-frame sensor.
    ``image_height`` is accepted but not used.

    :param focal_length_str: real focal length as an EXIF fraction string
    :param focal_length_35mm_str: 35 mm-equivalent focal length
    :param image_width: image width in pixels
    :return: focal length in pixels
    """
    full_frame_width_mm = 36
    real_mm = parse_exif_fraction(focal_length_str)
    equivalent_mm = float(focal_length_35mm_str)
    crop = equivalent_mm / real_mm
    estimated_sensor_width_mm = full_frame_width_mm / crop
    pixels_per_mm = image_width / estimated_sensor_width_mm
    return real_mm * pixels_per_mm
def parse_gps_pos_data(metadata):
    """Extract position and gimbal attitude from DJI XMP tags.

    Each value is read from its ``Xmp.drone-dji.*`` key and converted with
    ``float``; a key that is absent yields 0.0.

    :param metadata: mapping of tag name to string value
    :return: tuple (latitude, longitude, altitude, pitch, roll, yaw)
    :raises ValueError: if a present value cannot be converted to float
    """
    def _number(*keys):
        # Return the first tag that is present, trying the keys in order.
        for key in keys:
            value = metadata.get(key)
            if value is not None:
                return float(value)
        return 0.0

    latitude = _number('Xmp.drone-dji.GpsLatitude')
    # Some DJI firmware is known to write the misspelled tag "GpsLongtitude";
    # fall back to it so the longitude does not silently become 0.
    longitude = _number('Xmp.drone-dji.GpsLongitude', 'Xmp.drone-dji.GpsLongtitude')
    altitude = _number('Xmp.drone-dji.AbsoluteAltitude')
    # Tag names may need adjusting for other camera models.
    pitch = _number('Xmp.drone-dji.GimbalPitchDegree')
    roll = _number('Xmp.drone-dji.GimbalRollDegree')
    yaw = _number('Xmp.drone-dji.GimbalYawDegree')
    return latitude, longitude, altitude, pitch, roll, yaw
def read_metadata(image_path, name, datum='wgs84', zone_width=6, projection='utm'):
    """Read EXIF and XMP tags of one image and build its ImageMetadata.

    :param image_path: directory that contains the image
    :param name: image file name inside ``image_path``
    :param datum: forwarded to ImageMetadata
    :param zone_width: forwarded to ImageMetadata
    :param projection: forwarded to ImageMetadata
    :return: ImageMetadata populated from the file's tags
    """
    full_path = os.path.join(image_path, name)
    with pyexiv2.Image(full_path) as img:
        # XMP entries are merged on top of the EXIF entries.
        tags = img.read_exif()
        tags.update(img.read_xmp())

    real_focal = tags['Exif.Photo.FocalLength']  # real focal length (mm)
    equivalent_focal = tags['Exif.Photo.FocalLengthIn35mmFilm']  # 35 mm equivalent (mm)
    image_width = float(tags['Exif.Photo.PixelXDimension'])
    image_height = float(tags['Exif.Photo.PixelYDimension'])
    focal_length_pixels = get_focal_length_in_pixels(
        real_focal, equivalent_focal, image_width, image_height
    )

    pose = parse_gps_pos_data(tags)
    latitude, longitude, altitude, pitch, roll, yaw = pose
    return ImageMetadata(
        name, image_width, image_height, focal_length_pixels,
        latitude, longitude, altitude, pitch, roll, yaw,
        datum, zone_width, projection,
    )
def gimbal2nPOK(roll, pitch, yaw):
    """Build a rotation matrix from gimbal roll, pitch and yaw given in degrees.

    :return: list of 9 floats, the 3x3 matrix flattened; phoZ2grd reads
        elements 0-2, 3-5 and 6-8 as its three rows
    """
    deg_to_rad = math.pi / 180
    roll, pitch, yaw = (angle * deg_to_rad for angle in (roll, pitch, yaw))

    # Evaluate every sine/cosine once and reuse it below.
    sr, cr = math.sin(roll), math.cos(roll)
    sp, cp = math.sin(pitch), math.cos(pitch)
    sy, cy = math.sin(yaw), math.cos(yaw)

    return [
        cr * cy + sr * sp * sy,
        sr * cy - cr * sp * sy,
        -cp * sy,
        sr * sp * cy - cr * sy,
        -cr * sp * cy - sr * sy,
        -cp * cy,
        -sr * cp,
        cr * cp,
        -sp,
    ]
def phoZ2grd(px, py, gz, xs, ys, zs, rot, focus):
    """Intersect the ray through image point (px, py) with the plane Z = gz.

    :param px, py: image coordinates relative to the image centre
    :param gz: height of the ground plane
    :param xs, ys, zs: camera position
    :param rot: 9-element rotation matrix
    :param focus: focal length in the same unit as px/py
    :return: ground coordinates (gx, gy); the camera position (xs, ys) when
        the ray has no Z component
    """
    def direction(first):
        # Component of the ray direction built from three consecutive elements.
        return rot[first] * px + rot[first + 1] * py - rot[first + 2] * focus

    dir_x, dir_y, dir_z = direction(0), direction(3), direction(6)
    if dir_z == 0:
        scale = 0
    else:
        scale = (gz - zs) / dir_z
    return xs + dir_x * scale, ys + dir_y * scale
def grd2pho(gx, gy, gz, focus, x0, y0, z0, rotM):
    """Project ground point (gx, gy, gz) into an image.

    :param focus: focal length
    :param x0, y0, z0: camera position
    :param rotM: 9-element rotation matrix
    :return: image coordinates (px, py) relative to the image centre
    """
    dx, dy, dz = gx - x0, gy - y0, gz - z0
    depth = rotM[2] * dx + rotM[5] * dy + rotM[8] * dz
    scale = -focus / depth
    px = (rotM[0] * dx + rotM[3] * dy + rotM[6] * dz) * scale
    py = (rotM[1] * dx + rotM[4] * dy + rotM[7] * dz) * scale
    return px, py
def calculate_projection_and_compare(x1, y1, x2, y2, img1metadata, img2metadata, gz, threshold):
    """Check a match by projecting it through the ground plane Z = gz.

    Point (x1, y1) of the first image is projected onto the ground and from
    there into the second image; the match is accepted when that projection
    lies within ``threshold`` pixels of (x2, y2).

    :return: False when the projection falls more than ``threshold`` outside
        the second image or is farther than ``threshold`` from (x2, y2),
        True otherwise
    """
    left, right = img1metadata, img2metadata

    # Pixel coordinates -> coordinates centred on the image, y pointing up.
    px1 = x1 - left.width / 2
    py1 = left.height / 2 - y1

    # First image -> ground -> second image.
    gx, gy = phoZ2grd(px1, py1, gz, left.Xs, left.Ys, left.Zs,
                      left.rotation_matrix, left.f)
    px2, py2 = grd2pho(gx, gy, gz, right.f, right.Xs, right.Ys, right.Zs,
                       right.rotation_matrix)

    # Back to pixel coordinates of the second image.
    x2_pro = px2 + right.width / 2
    y2_pro = right.height / 2 - py2

    # Reject projections that land too far outside the second image.
    too_far_left = x2_pro < 0 - threshold
    too_far_up = y2_pro < 0 - threshold
    too_far_right = x2_pro > right.width + threshold
    too_far_down = y2_pro > right.height + threshold
    if too_far_left or too_far_up or too_far_right or too_far_down:
        return False

    # Accept only when the projection is close enough to the matched point.
    distance = ((x2_pro - x2) ** 2 + (y2_pro - y2) ** 2) ** 0.5
    return not (distance > threshold)
def process_matches(args, image_metadata_dict, matches, keypoints):
    """Group pairwise matches into 3D points (tracks).

    Every match links keypoint ``idxInA`` of image ``id1`` with keypoint
    ``idxInB`` of image ``id2``. When ``args.filter_with_pos`` is set, a match
    is first checked with calculate_projection_and_compare() and dropped if it
    fails. Without filtering every match is kept.

    :param args: namespace with the boolean ``filter_with_pos``; the optional
        attributes ``ground_height`` (default 40) and ``reproj_threshold``
        (default 300) override the filter parameters
    :param image_metadata_dict: image id -> ImageMetadata, used when filtering
    :param matches: pair id -> iterable of (idxInA, idxInB) index pairs
    :param keypoints: image id -> array of (x, y) keypoint coordinates
    :return: tuple ``(featurePoints, points3Ds)``: a dict keyed by
        (image id, keypoint index) holding FeaturePoint objects, and a dict
        keyed by point id holding Point3D objects
    """
    # Both values depend on the survey area and usually need tuning.
    ground_height = getattr(args, 'ground_height', 40)
    threshold = getattr(args, 'reproj_threshold', 300)

    points3Ds = {}
    featurePoints = {}
    next_point3D_id = 1
    for pair_id, match_data in matches.items():
        id1, id2 = pair_id_to_image_ids(pair_id)
        for idxInA, idxInB in match_data:
            # Plain ints keep dictionary keys and written indices free of numpy types.
            idxInA, idxInB = int(idxInA), int(idxInB)
            keyA = (id1, idxInA)
            keyB = (id2, idxInB)
            if keyA not in featurePoints:
                coordsA = keypoints[id1][idxInA]
                featurePoints[keyA] = FeaturePoint(coordsA[0], coordsA[1], -1)
            if keyB not in featurePoints:
                coordsB = keypoints[id2][idxInB]
                featurePoints[keyB] = FeaturePoint(coordsB[0], coordsB[1], -1)
            fpA = featurePoints[keyA]
            fpB = featurePoints[keyB]

            if args.filter_with_pos:
                # Use the position/attitude data to reject gross mismatches.
                keepPoint = calculate_projection_and_compare(
                    fpA.x, fpA.y, fpB.x, fpB.y,
                    image_metadata_dict[id1], image_metadata_dict[id2],
                    ground_height, threshold,
                )
                if not keepPoint:
                    continue

            if fpA.point3D_id == -1 and fpB.point3D_id == -1:
                # Neither keypoint belongs to a track yet: start a new one.
                point3D_id = next_point3D_id
                next_point3D_id += 1
                newPoint3D = Point3D()
                newPoint3D.add_observation(Observation(id1, idxInA))
                newPoint3D.add_observation(Observation(id2, idxInB))
                points3Ds[point3D_id] = newPoint3D
            elif fpA.point3D_id != -1 and fpB.point3D_id != -1:
                # Both already belong to a track. If it is the same one there is
                # nothing to add; if they differ, re-labelling one keypoint
                # would leave its observation in the other track, so the match
                # is skipped to keep ids and tracks consistent.
                continue
            else:
                # Exactly one keypoint is unassigned: attach it to the other's track.
                point3D_id = fpA.point3D_id if fpA.point3D_id != -1 else fpB.point3D_id
                point3D = points3Ds[point3D_id]
                if fpA.point3D_id == -1:
                    point3D.add_observation(Observation(id1, idxInA))
                else:
                    point3D.add_observation(Observation(id2, idxInB))
            fpA.point3D_id = point3D_id
            fpB.point3D_id = point3D_id
    return featurePoints, points3Ds
def save_images_txt(img_ids_to_names_dict, keypoints, featurePoints, output_path):
    """Write ``images.txt`` into ``output_path``.

    After four header lines, each image gets two lines: a pose line with all
    rotation/translation values set to 0 and camera id 1, followed by one
    ``x y point3D_id`` triple per keypoint (-1 when the keypoint has no
    3D point).
    """
    header = (
        "# Image list with two lines of data per image:\n"
        "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n"
        "# POINTS2D[] as (X, Y, POINT3D_ID)\n"
        f"# Number of images: {len(img_ids_to_names_dict)}, mean observations per image: \n"
    )
    target = os.path.join(output_path, 'images.txt')
    with open(target, 'w') as f:
        f.write(header)
        for image_id, image_name in img_ids_to_names_dict.items():
            # Pose is a placeholder: rotation and translation 0, camera id 1.
            f.write(f"{image_id} 0 0 0 0 0 0 0 1 {image_name}\n")
            entries = []
            for idx, (x, y) in enumerate(keypoints[image_id]):
                feature = featurePoints.get((image_id, idx))
                # Keypoints that never appeared in a match have no entry.
                point3D_id = -1 if feature is None else feature.point3D_id
                entries.append(f"{x} {y} {point3D_id} ")
            f.write("".join(entries) + "\n")
def preload_images(image_path, img_ids_to_names_dict):
    """Load every listed image and return them keyed by image id."""
    return {
        image_id: load_image(os.path.join(image_path, img_name))
        for image_id, img_name in img_ids_to_names_dict.items()
    }
def save_points3Ds_to_file(image_path, img_ids_to_names_dict, points3Ds, keypoints, output_path):
    """Write ``points3D.txt`` into ``output_path``.

    Each point is coloured with the mean of the pixel colours at its
    observations and the result is also stored on ``point.color``.
    Observations whose image could not be loaded, or whose keypoint falls
    outside the image, are left out of the mean; a point without any usable
    colour is written as black instead of aborting the export.

    :param image_path: directory that contains the images
    :param img_ids_to_names_dict: image id -> file name
    :param points3Ds: point id -> Point3D
    :param keypoints: image id -> array of (x, y) keypoint coordinates
    :param output_path: directory in which the file is created
    """
    images = preload_images(image_path, img_ids_to_names_dict)
    with open(os.path.join(output_path, 'points3D.txt'), 'w') as f:
        f.write("# 3D point list with one line of data per point:\n")
        f.write("# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n")
        f.write(f"# Number of points: {len(points3Ds)}, mean track length: \n")
        for point3D_id, point in points3Ds.items():
            # Collect the colour seen at every usable observation.
            colors = []
            for obs in point.observations:
                image = images[obs.image_id]
                if image is None:
                    continue  # the image file could not be read
                coords = keypoints[obs.image_id][obs.point2D_idx]
                color = get_color_from_image(image, coords[0], coords[1])
                if color is not None:
                    # Plain ints so that summing cannot wrap around in uint8.
                    colors.append(tuple(int(channel) for channel in color))
            mean_color = average_color(colors)
            point.color = mean_color if mean_color is not None else (0, 0, 0)
            red, green, blue = (int(channel) for channel in point.color)
            # Coordinates, colour and a zero error value.
            f.write(f"{point3D_id} {point.x} {point.y} {point.z} {red} {green} {blue} 0 ")
            # Track: one (image id, keypoint index) pair per observation.
            for obs in point.observations:
                f.write(f"{int(obs.image_id)} {obs.point2D_idx} ")
            f.write("\n")
def load_image(image_path):
    """Load the image at ``image_path`` with OpenCV in colour mode."""
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    return image
def get_color_from_image(image, x, y):
    """Return the (r, g, b) colour of the pixel nearest to (x, y).

    :param image: array indexed as ``image[row, column]`` with channels stored
        in B, G, R order
    :param x: column coordinate
    :param y: row coordinate
    :return: tuple of three Python ints, or None when the rounded position
        lies outside the image
    """
    x, y = int(round(x)), int(round(y))
    if x < 0 or y < 0 or x >= image.shape[1] or y >= image.shape[0]:
        return None
    # The array is addressed as (row, column) and stores blue first.
    b, g, r = image[y, x]
    # Return plain ints: adding fixed-width 8-bit scalars later could wrap around.
    return int(r), int(g), int(b)
def average_color(colors):
    """Return the channel-wise mean of a list of colour tuples.

    ``None`` entries are ignored, and each channel is summed as float so that
    fixed-width integer inputs cannot overflow.

    :param colors: iterable of equally long colour tuples, possibly with None
    :return: tuple of float means, or None when no colour is available
    """
    valid = [color for color in (colors or ()) if color is not None]
    if not valid:
        return None
    count = len(valid)
    return tuple(
        sum(float(value) for value in channel) / count for channel in zip(*valid)
    )
def main():
    """Command-line entry point.

    Reads images, keypoints and two-view matches from the database, optionally
    filters the matches with the images' position/attitude metadata, and
    writes ``images.txt`` and ``points3D.txt`` to the output directory.
    """
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument('--database_path', required=True, help='Path to the database')
    parser.add_argument('--output_path', required=True, help='Name of the output directory')
    parser.add_argument('--filter_with_pos', action='store_true', help="if use pos to filter match")
    parser.add_argument('--images_path', required=True)
    args = parser.parse_args()

    filename_db = args.database_path
    print("Opening database: " + filename_db)
    if not os.path.exists(filename_db):
        print('Error db does not exist!')
        # Exit with a non-zero status so that callers can detect the failure.
        sys.exit(1)
    # makedirs also creates missing parent directories.
    os.makedirs(args.output_path, exist_ok=True)

    img_ids_to_names_dict = {}
    # Image id -> metadata, filled only when filtering is requested.
    image_metadata_dict = {}
    all_matches = {}

    connection = sqlite3.connect(filename_db)
    try:
        cursor = connection.cursor()

        # Image ids and names.
        cursor.execute('SELECT image_id, name, cameras.width, cameras.height FROM images LEFT JOIN cameras ON images.camera_id == cameras.camera_id;')
        for image_idx, name, _width, _height in cursor.fetchall():
            img_ids_to_names_dict[image_idx] = name
            if args.filter_with_pos:
                # Adjust datum / zone width / projection to the survey area.
                image_metadata_dict[image_idx] = read_metadata(
                    args.images_path, name, 'cgcs2000', 3, 'gauss'
                )

        keypoints = {
            image_id: get_keypoints(cursor, image_id)
            for image_id in img_ids_to_names_dict
        }

        # Matches; pairs with fewer than 5 matches are ignored.
        cursor.execute('SELECT pair_id, rows, cols, data FROM two_view_geometries;')
        for pair_id, rows, cols, raw_data in cursor:
            if rows < 5:
                continue
            all_matches[pair_id] = np.frombuffer(raw_data, dtype=np.uint32).reshape(rows, cols)
        cursor.close()
    finally:
        # Release the database even when reading fails part-way.
        connection.close()

    featurePoints, points3Ds = process_matches(args, image_metadata_dict, all_matches, keypoints)
    save_images_txt(img_ids_to_names_dict, keypoints, featurePoints, args.output_path)
    save_points3Ds_to_file(args.images_path, img_ids_to_names_dict, points3Ds, keypoints, args.output_path)
# Run the conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()
其中,
read_metadata()函数的投影方式
calculate_projection_and_compare()函数中的最后两个参数(地面高度gz和像素距离阈值threshold)
以上两个函数需根据需求和测区情况对参数进行调整。
结果
红线标出则为滤除的误匹配点
(以上代码未提供画出匹配结果的部分代码,不会生成下图)