Using image GPS/POS information to remove gross errors from image matching

I previously wrote a post on reading the match results from COLMAP's database.db and writing them to images.txt and points3D.txt.
The matching itself is done by COLMAP's SIFT pipeline. However, if the position and attitude (GPS/POS) of each image are available, they can be used to constrain the match results: set an approximate ground elevation, project a point (x1, y1) on the left image onto the ground, and then reproject it into the right image as (x2_, y2_). If the distance between this projected point and the matched point (x2, y2) on the right image exceeds a threshold, the match is considered a gross error.
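
In equation form, this is a compact restatement of what the phoZ2grd() and grd2pho() functions below implement (the standard photogrammetric collinearity equations). With $(X_S, Y_S, Z_S)$ the projection centre of an image, $f$ its focal length in pixels, $a_i, b_i, c_i$ the elements of its flattened, row-major rotation matrix, $Z_g$ the assumed ground elevation, and image coordinates measured from the image centre, the point $(x_1, y_1)$ on the left image is first dropped onto the ground plane:

$$
X_g = X_{S1} + (Z_g - Z_{S1})\,\frac{a_1 x_1 + a_2 y_1 - a_3 f_1}{c_1 x_1 + c_2 y_1 - c_3 f_1},\qquad
Y_g = Y_{S1} + (Z_g - Z_{S1})\,\frac{b_1 x_1 + b_2 y_1 - b_3 f_1}{c_1 x_1 + c_2 y_1 - c_3 f_1}
$$

and the ground point is then reprojected into the right image, using the right image's orientation elements and $\Delta X = X_g - X_{S2}$, etc.:

$$
\hat{x}_2 = -f_2\,\frac{a_1 \Delta X + b_1 \Delta Y + c_1 \Delta Z}{a_3 \Delta X + b_3 \Delta Y + c_3 \Delta Z},\qquad
\hat{y}_2 = -f_2\,\frac{a_2 \Delta X + b_2 \Delta Y + c_2 \Delta Z}{a_3 \Delta X + b_3 \Delta Y + c_3 \Delta Z}
$$

The match is kept only if $\sqrt{(\hat{x}_2 - x_2)^2 + (\hat{y}_2 - y_2)^2} \le T$ for a chosen threshold $T$ in pixels, with the matched point $(x_2, y_2)$ expressed in the same image-centred coordinates.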

The code below modifies the previous version. It still reads the matches from the database and writes points3D.txt and images.txt, but adds two features: reading GPS/POS information from the images, and filtering false matches based on that GPS/POS data.
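
Besides the Python standard library, the script relies on a few third-party packages: numpy, opencv-python (imported as cv2), pyexiv2 (the PyPI package that provides the pyexiv2.Image API used here for EXIF/XMP metadata) and pyproj (for the map projection). A typical install command (package names only; pin versions to suit your environment) would be:

pip install numpy opencv-python pyexiv2 pyproj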

import os
import argparse
import sqlite3
import cv2
import math
import numpy as np

import pyexiv2
from pyproj import Transformer  # the legacy Proj(init=...)/transform() API is deprecated

class FeaturePoint:
    def __init__(self, x, y, point3D_id=-1):
        self.x = x
        self.y = y
        self.point3D_id = point3D_id

class Observation:
    def __init__(self, image_id, point2D_idx):
        self.image_id = image_id
        self.point2D_idx = point2D_idx

class Point3D:
    def __init__(self, x=0, y=0, z=0):
        self.x = x
        self.y = y
        self.z = z
        self.observations = []

    def add_observation(self, observation):
        self.observations.append(observation)


def pair_id_to_image_ids(pair_id):
    # COLMAP packs the two image ids of a pair as pair_id = id1 * 2147483647 + id2
    image_id2 = pair_id % 2147483647
    image_id1 = (pair_id - image_id2) // 2147483647
    return int(image_id1), int(image_id2)

def get_keypoints(cursor, image_id):
    cursor.execute("SELECT * FROM keypoints WHERE image_id = ?;", (image_id,))        
    _, n_rows, n_columns, raw_data = cursor.fetchone()
    keypoints = np.frombuffer(raw_data, dtype=np.float32).reshape(n_rows, n_columns)
    return keypoints[:, :2]  # Assuming first two columns are x, y coordinates

class ImageMetadata:
    def __init__(self, name, width, height, focal_length_pixels, latitude, longitude, altitude,
                 pitch, roll, yaw, datum='wgs84', zone_width=3, projection='gauss'):
        self.name = name
        self.width = width
        self.height = height
        self.f = focal_length_pixels
        self.latitude = latitude
        self.longitude = longitude
        self.altitude = altitude
        self.pitch = pitch
        self.roll = roll
        self.yaw = yaw
        self.Xs, self.Ys = self.calculate_projection_coordinates(datum, zone_width, projection)
        self.Zs = altitude
        self.rotation_matrix = gimbal2nPOK(roll, pitch, yaw)

    def calculate_projection_coordinates(self, datum, zone_width, projection):
        """Convert latitude/longitude to projected (easting, northing) coordinates."""
        easting, northing = latlon_to_custom_projection(self.latitude, self.longitude, datum, zone_width, projection)
        return easting, northing

def get_decimal_from_dms(dms, ref):
    """Convert degrees/minutes/seconds to decimal degrees."""
    degrees, minutes, seconds = dms
    decimal = degrees + minutes / 60 + seconds / 3600
    if ref in ['S', 'W']:
        decimal *= -1
    return decimal

# Alternative (not used here): parse the position from the standard EXIF GPSInfo tags
# instead of the DJI XMP tags read by parse_gps_pos_data() below.
# def parse_gps_data(metadata):
#     """Parse GPS position from standard EXIF GPSInfo tags."""
#     lat_dms = [float(x) for x in metadata.get('Exif.GPSInfo.GPSLatitude', '0 0 0').split()]
#     lat_ref = metadata.get('Exif.GPSInfo.GPSLatitudeRef', 'N')
#     lon_dms = [float(x) for x in metadata.get('Exif.GPSInfo.GPSLongitude', '0 0 0').split()]
#     lon_ref = metadata.get('Exif.GPSInfo.GPSLongitudeRef', 'E')
#
#     latitude = get_decimal_from_dms(lat_dms, lat_ref)
#     longitude = get_decimal_from_dms(lon_dms, lon_ref)
#     altitude = float(metadata.get('Exif.GPSInfo.GPSAltitude', '0'))
#
#     return latitude, longitude, altitude

def parse_exif_fraction(exif_fraction_str):
    """Parse an EXIF rational string such as '28/1' and return it as a float."""
    if '/' not in str(exif_fraction_str):
        return float(exif_fraction_str)
    numerator, denominator = map(float, exif_fraction_str.split('/'))
    return numerator / denominator if denominator != 0 else 0

def latlon_to_custom_projection(latitude, longitude, datum='wgs84', zone_width=3, projection='gauss'):
    """
    Convert geographic coordinates to projected plane coordinates.
    :param latitude: latitude in decimal degrees
    :param longitude: longitude in decimal degrees
    :param datum: reference datum ('wgs84' or 'cgcs2000')
    :param zone_width: zone width in degrees (3 or 6)
    :param projection: projection type ('utm' or 'gauss')
    :return: projected coordinates (easting, northing)
    """
    # Zone number: UTM zones start at 180 deg W; 3-degree Gauss-Krueger zones are centred
    # on multiples of 3 deg, while 6-degree zones start at the Greenwich meridian.
    if projection == 'utm':
        zone_number = int(math.floor((longitude + 180) / zone_width) + 1)
    elif zone_width == 3:
        zone_number = int(math.floor((longitude + 1.5) / zone_width))
    else:
        zone_number = int(math.floor(longitude / zone_width) + 1)

    if datum == 'wgs84':
        ellps = 'WGS84'
    elif datum == 'cgcs2000':
        ellps = 'GRS80'
    else:
        raise ValueError("Unsupported datum. Choose 'wgs84' or 'cgcs2000'.")

    # Build the PROJ string (+datum is omitted because PROJ has no datum code for CGCS2000;
    # the ellipsoid is what matters for the projection itself)
    if projection == 'utm':
        proj_str = f'+proj=utm +zone={zone_number} +ellps={ellps} +units=m +no_defs'
    elif projection == 'gauss':
        central_meridian = zone_number * zone_width if zone_width == 3 else (zone_number - 1) * zone_width + 3
        proj_str = f'+proj=tmerc +lat_0=0 +lon_0={central_meridian} +k=1 +x_0=500000 +y_0=0 +ellps={ellps} +units=m +no_defs'
    else:
        raise ValueError("Unsupported projection. Choose 'utm' or 'gauss'.")

    # Transform from WGS84 geographic coordinates (EPSG:4326) to the target projection;
    # always_xy=True makes the input order (longitude, latitude) and the output (easting, northing).
    transformer = Transformer.from_crs('EPSG:4326', proj_str, always_xy=True)
    easting, northing = transformer.transform(longitude, latitude)

    return easting, northing

def get_focal_length_in_pixels(focal_length_str, focal_length_35mm_str, image_width, image_height):
    focal_length_mm = parse_exif_fraction(focal_length_str)  # physical focal length (mm)
    focal_length_35mm = float(focal_length_35mm_str)         # 35 mm equivalent focal length (mm)

    # A full-frame (35 mm) sensor is about 36 mm wide
    sensor_width_mm_35mm = 36

    # Crop factor of the actual sensor
    crop_factor = focal_length_35mm / focal_length_mm

    # Physical width of the sensor (mm)
    sensor_width_mm = sensor_width_mm_35mm / crop_factor

    # Focal length in pixels
    focal_length_pixels = focal_length_mm * (image_width / sensor_width_mm)

    return focal_length_pixels
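
# A quick worked example with hypothetical values: FocalLength = '28/1',
# FocalLengthIn35mmFilm = '75' and an image 8000 px wide give a crop factor of
# 75 / 28 ≈ 2.68, a sensor width of 36 / 2.68 ≈ 13.44 mm, and a focal length of
# 28 * 8000 / 13.44 ≈ 16667 px (equivalently, 75 * 8000 / 36).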


def parse_gps_pos_data(metadata):
    """Parse the GPS position from DJI XMP tags."""
    latitude = float(metadata.get('Xmp.drone-dji.GpsLatitude', '0'))
    longitude = float(metadata.get('Xmp.drone-dji.GpsLongitude', '0'))
    altitude = float(metadata.get('Xmp.drone-dji.AbsoluteAltitude', '0'))
    # Parse the gimbal attitude; the tag names may need to be adjusted for other cameras
    pitch = float(metadata.get('Xmp.drone-dji.GimbalPitchDegree', '0'))
    roll = float(metadata.get('Xmp.drone-dji.GimbalRollDegree', '0'))
    yaw = float(metadata.get('Xmp.drone-dji.GimbalYawDegree', '0'))

    return latitude, longitude, altitude, pitch, roll, yaw



def read_metadata(image_path, name, datum='wgs84', zone_width=6, projection='utm'):
    """Read the EXIF/XMP metadata of an image file."""
    with pyexiv2.Image(os.path.join(image_path, name)) as img:
        metadata = img.read_exif()
        metadata.update(img.read_xmp())
        focal_length_mm = metadata['Exif.Photo.FocalLength']              # physical focal length (mm)
        focal_length_35mm = metadata['Exif.Photo.FocalLengthIn35mmFilm']  # 35 mm equivalent focal length (mm)
        image_width = float(metadata['Exif.Photo.PixelXDimension'])
        image_height = float(metadata['Exif.Photo.PixelYDimension'])
        focal_length_pixels = get_focal_length_in_pixels(focal_length_mm, focal_length_35mm, image_width, image_height)
        # latitude, longitude, altitude = parse_gps_data(metadata)
        latitude, longitude, altitude, pitch, roll, yaw = parse_gps_pos_data(metadata)

        return ImageMetadata(name, image_width, image_height, focal_length_pixels, latitude, longitude, altitude,
                             pitch, roll, yaw, datum, zone_width, projection)


def gimbal2nPOK(roll, pitch, yaw):
    """Build the flattened 3x3 rotation matrix from gimbal roll/pitch/yaw angles given in degrees."""
    d2r = math.pi / 180  # degrees-to-radians factor

    roll = roll * d2r
    pitch = pitch * d2r
    yaw = yaw * d2r

    R = [0] * 9  # flattened 3x3 rotation matrix

    R[0] = math.cos(roll) * math.cos(yaw) + math.sin(roll) * math.sin(pitch) * math.sin(yaw)
    R[1] = math.sin(roll) * math.cos(yaw) - math.cos(roll) * math.sin(pitch) * math.sin(yaw)
    R[2] = -math.cos(pitch) * math.sin(yaw)

    R[3] = math.sin(roll) * math.sin(pitch) * math.cos(yaw) - math.cos(roll) * math.sin(yaw)
    R[4] = -math.cos(roll) * math.sin(pitch) * math.cos(yaw) - math.sin(roll) * math.sin(yaw)
    R[5] = -math.cos(pitch) * math.cos(yaw)

    R[6] = -math.sin(roll) * math.cos(pitch)
    R[7] = math.cos(roll) * math.cos(pitch)
    R[8] = -math.sin(pitch)

    return R

def phoZ2grd(px, py, gz, xs, ys, zs, rot, focus):
    """Project an image-plane point (px, py) onto the ground plane Z = gz using the
    collinearity equations; rot is the flattened 3x3 rotation matrix of the camera."""
    A = rot[0] * px + rot[1] * py - rot[2] * focus
    B = rot[3] * px + rot[4] * py - rot[5] * focus
    C = rot[6] * px + rot[7] * py - rot[8] * focus
    N = 0 if C == 0 else (gz - zs) / C
    gx = xs + A * N
    gy = ys + B * N
    return gx, gy

def grd2pho(gx, gy, gz, focus, x0, y0, z0, rotM):
    """Project a ground point (gx, gy, gz) into the image plane of a camera with
    projection centre (x0, y0, z0) and flattened rotation matrix rotM."""
    fm = -focus / (rotM[2] * (gx - x0) + rotM[5] * (gy - y0) + rotM[8] * (gz - z0))
    px = (rotM[0] * (gx - x0) + rotM[3] * (gy - y0) + rotM[6] * (gz - z0)) * fm
    py = (rotM[1] * (gx - x0) + rotM[4] * (gy - y0) + rotM[7] * (gz - z0)) * fm
    return px, py


def calculate_projection_and_compare(x1, y1, x2, y2, img1metadata, img2metadata, gz, threshold):
    # Convert the pixel coordinates to image-plane coordinates with the origin at the image centre
    px1 = x1 - img1metadata.width / 2
    py1 = img1metadata.height / 2 - y1

    # Project the point on the first image onto the ground plane Z = gz
    gx, gy = phoZ2grd(px1, py1, gz, img1metadata.Xs, img1metadata.Ys, img1metadata.Zs, img1metadata.rotation_matrix, img1metadata.f)

    # Reproject the ground point into the second image
    px2, py2 = grd2pho(gx, gy, gz, img2metadata.f, img2metadata.Xs, img2metadata.Ys, img2metadata.Zs, img2metadata.rotation_matrix)

    # Convert back to pixel coordinates of the second image
    x2_pro = px2 + img2metadata.width / 2
    y2_pro = img2metadata.height / 2 - py2

    # Reject the match if the projected point falls too far outside the second image
    if x2_pro < -threshold or y2_pro < -threshold or x2_pro > img2metadata.width + threshold or y2_pro > img2metadata.height + threshold:
        return False

    # Compare the distance between the projected point and the matched point with the threshold
    distance = ((x2_pro - x2) ** 2 + (y2_pro - y2) ** 2) ** 0.5
    if distance > threshold:
        return False

    return True


def process_matches(args, image_metadata_dict, matches, keypoints):
    points3Ds = {}
    featurePoints = {}
    next_point3D_id = 1

    for pair_id, match_data in matches.items():
        id1, id2 = pair_id_to_image_ids(pair_id)
        for idxInA, idxInB in match_data:
            keyA = (id1, idxInA)
            keyB = (id2, idxInB)
            if keyA not in featurePoints:
                coordsA = keypoints[id1][idxInA]
                featurePoints[keyA] = FeaturePoint(coordsA[0], coordsA[1], -1)
            if keyB not in featurePoints:
                coordsB = keypoints[id2][idxInB]
                featurePoints[keyB] = FeaturePoint(coordsB[0], coordsB[1], -1)

            fpA = featurePoints[keyA]
            fpB = featurePoints[keyB]

            if args.filter_with_pos:
                # Use the POS information to reject gross matches; the last two arguments
                # (ground elevation gz and the distance threshold in pixels) must be tuned manually
                img1metadata = image_metadata_dict[id1]
                img2metadata = image_metadata_dict[id2]
                keepPoint = calculate_projection_and_compare(fpA.x, fpA.y, fpB.x, fpB.y, img1metadata, img2metadata, 40, 300)
            else:
                # Without --filter_with_pos, keep every match
                keepPoint = True

            if keepPoint:
                if fpA.point3D_id == -1 and fpB.point3D_id == -1:
                    newPoint3D = Point3D()
                    point3D_id = next_point3D_id
                    next_point3D_id += 1
                    newPoint3D.add_observation(Observation(id1, idxInA))
                    newPoint3D.add_observation(Observation(id2, idxInB))
                    points3Ds[point3D_id] = newPoint3D
                else:
                    point3D_id = fpA.point3D_id if fpA.point3D_id != -1 else fpB.point3D_id
                    point3D = points3Ds[point3D_id]
                    if fpA.point3D_id == -1:
                        point3D.add_observation(Observation(id1, idxInA))
                    if fpB.point3D_id == -1:
                        point3D.add_observation(Observation(id2, idxInB))

                fpA.point3D_id = point3D_id
                fpB.point3D_id = point3D_id

    return featurePoints, points3Ds

def save_images_txt(img_ids_to_names_dict, keypoints, featurePoints, output_path):
    with open(os.path.join(output_path, 'images.txt'), 'w') as f:
        f.write(f"# Image list with two lines of data per image:\n")
        f.write(f"#   IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n")
        f.write(f"#   POINTS2D[] as (X, Y, POINT3D_ID)\n")
        f.write(f"# Number of images: {len(img_ids_to_names_dict)}, mean observations per image: \n")

        for image_id in img_ids_to_names_dict:
            # Write the image line; the pose is left as zeros and the camera id is assumed to be 1
            f.write(f"{image_id} 0 0 0 0 0 0 0 1 {img_ids_to_names_dict[image_id]}\n")
            # Write all 2D points of this image; POINT3D_ID is -1 when the feature has no 3D point
            for idx, (x, y) in enumerate(keypoints[image_id]):
                fp = featurePoints.get((image_id, idx))
                point3D_id = fp.point3D_id if fp is not None else -1
                f.write(f"{x} {y} {point3D_id} ")
            f.write("\n")


def preload_images(image_path, img_ids_to_names_dict):
    images = {}
    for image_id, img_name in img_ids_to_names_dict.items():
        img_file_name = os.path.join(image_path, img_name)
        images[image_id] = load_image(img_file_name)
    return images


def save_points3Ds_to_file(image_path, img_ids_to_names_dict, points3Ds, keypoints, output_path):
    images = preload_images(image_path, img_ids_to_names_dict)

    with open(os.path.join(output_path, 'points3D.txt'), 'w') as f:
        f.write(f"# 3D point list with one line of data per point:\n")
        f.write(f"#   POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n")
        f.write(f"# Number of points: {len(points3Ds)}, mean track length: \n")
        for point3D_id, point in points3Ds.items():
            # Colour of the 3D point: average the colours sampled at its observations
            colors = []
            for obs in point.observations:
                image = images[obs.image_id]  # use the preloaded image
                coords = keypoints[obs.image_id][obs.point2D_idx]
                color = get_color_from_image(image, coords[0], coords[1])
                if color is not None:
                    colors.append(color)
            point.color = average_color(colors) or (0, 0, 0)
            # Write the 3D point coordinates and colour
            f.write(f"{point3D_id} {point.x} {point.y} {point.z} {int(point.color[0])} {int(point.color[1])} {int(point.color[2])} 0 ")
            # Write the track (observations)
            for obs in point.observations:
                f.write(f"{int(obs.image_id)} {obs.point2D_idx} ")
            f.write("\n")


def load_image(image_path):
    # Load the image with OpenCV
    return cv2.imread(image_path, cv2.IMREAD_COLOR)

def get_color_from_image(image, x, y):
    # OpenCV indexes images as (row, col), i.e. (y, x), and stores channels as BGR
    x, y = int(round(x)), int(round(y))
    if x < 0 or y < 0 or x >= image.shape[1] or y >= image.shape[0]:
        return None  # out of bounds

    # Convert BGR to RGB
    b, g, r = image[y, x]
    return r, g, b

def average_color(colors):
    # Average a list of (R, G, B) colours channel by channel
    if not colors:
        return None
    avg_color = [sum(col) / len(colors) for col in zip(*colors)]
    return tuple(avg_color)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--database_path', required=True, help='Path to the database')
    parser.add_argument('--output_path', required=True, help='Name of the output directory')
    parser.add_argument('--filter_with_pos', action='store_true',help="if use pos to filter match")
    parser.add_argument('--images_path', required=True)
    args = parser.parse_args()

    filename_db = args.database_path

    print("Opening database: " + filename_db)
    if not os.path.exists(filename_db):
        print('Error: database does not exist!')
        exit()

    if not os.path.exists(args.output_path):
        os.mkdir(args.output_path)

    # Connect to the database
    connection = sqlite3.connect(args.database_path)
    cursor = connection.cursor()


    list_image_ids = []
    img_ids_to_names_dict = {}
    # Image id -> metadata (only filled when --filter_with_pos is set)
    image_metadata_dict = {}
    # Extract image ids, names and sizes
    cursor.execute('SELECT image_id, name, cameras.width, cameras.height FROM images LEFT JOIN cameras ON images.camera_id = cameras.camera_id;')
    for row in cursor:
        image_idx, name, width, height = row
        list_image_ids.append(image_idx)
        img_ids_to_names_dict[image_idx] = name
        if args.filter_with_pos:
            # Adjust the datum, zone width and projection to your own data
            metadata = read_metadata(args.images_path, name, 'cgcs2000', 3, 'gauss')
            image_metadata_dict[image_idx] = metadata
    
    num_image_ids = len(list_image_ids)

    keypoints = {image_id: get_keypoints(cursor, image_id) for image_id in list_image_ids}

    # Extract matches
    cursor.execute('SELECT pair_id, rows, cols, data FROM two_view_geometries;')
    all_matches = {}
    for row in cursor:
        pair_id = row[0]
        rows = row[1]
        cols = row[2]
        raw_data = row[3]
        # Skip pairs with too few verified matches
        if rows < 5:
            continue

        matches = np.frombuffer(raw_data, dtype=np.uint32).reshape(rows, cols)
        all_matches[pair_id] = matches

    # Process matches
    featurePoints,points3Ds = process_matches(args,image_metadata_dict,all_matches, keypoints)

    cursor.close()
    connection.close()

    save_images_txt(img_ids_to_names_dict,keypoints, featurePoints, args.output_path)
    save_points3Ds_to_file(args.images_path,img_ids_to_names_dict,points3Ds, keypoints,args.output_path)


if __name__ == "__main__":
    main()

Note that the following parameters need to be adjusted according to your requirements and the survey area:

the datum / zone width / projection arguments passed to read_metadata();
the last two arguments of calculate_projection_and_compare() (the assumed ground elevation gz and the distance threshold in pixels).

A sample invocation is shown below.
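
Assuming the script above is saved as filter_matches_with_pos.py (the file name here is just an example) and that the images referenced in database.db are in the folder passed as --images_path, a typical run looks like the following; omit --filter_with_pos to export the matches without the GPS/POS check:

python filter_matches_with_pos.py --database_path /path/to/database.db --images_path /path/to/images --output_path /path/to/output --filter_with_pos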

Results

In the figure below, the matches removed by the filter are marked with red lines.
(The code above does not include the part that draws the matches, so it will not generate this figure.)
[Figure: matching results with the filtered false matches highlighted in red]
