python 高效读取多个geojson 写入一个sq3(Sqlite) 、效率提高90%+

1.问题缘由:

        由于工作需求,需要将多个(总量10G+)geojson文件写入到sq3库,众所周知,sqlite 不支持多线程写入,那该怎么办呢,在网上也查了很多策略,都没有达到立竿见影的效果。于是还是回到写文件的本质:多线程写多文件,就绕开加锁的机制。

2.单线程读取的效果

单线程读写原始26个geojson文件,共294M,耗时:547S

写完的sq3文件大小:73.3M

3.多进程并发

多进程并发读写geojson,生成多个sq3文件,再合并到一个sq3文件耗时:16.5S

4.工具代码:

4.1 rw_data_geojson.py: 读写geojson文件

import os
import json

GEOMETRY = 'geometry'


def read_all_layer(src_path):
    """
    读取geojson 文件,传入读取文件路径,返回dict
    dict 是以layername为key,获取每个layer的dict,子字典的key为要素ID
    :param src_path: 读取geojson文件路径
    :return: 封装dict  {layerName:[id:{要素dict}]}
    """
    filenames = os.listdir(src_path)
    # 过滤出你想要处理的文件,例如只读取.txt文件
    txt_filenames = [f for f in filenames if f.endswith('.geojson')]
    geo_properties_map = {}
    # 循环读取每个文件
    for filename in txt_filenames:
        file_path = os.path.join(src_path, filename)
        with open(file_path, 'r') as file:
            content = file.read()
            geojson_data = json.loads(content)

        features = geojson_data.get('features', [])
        dict = {}
        for feature in features:
            properties = feature.get('properties')
            if GEOMETRY in feature:
                properties[GEOMETRY] = feature.get('geometry')
            dict[properties["id"]] = properties
        layername = filename.replace(".geojson", "")
        geo_properties_map[layername] = dict

    return geo_properties_map


def read_single_layer(geojson_path):
    """
    读取指定geojson 文件,返回dict  dict 是以layername为key,获取每个layer的dict,子字典的key为要素ID
    :param geojson_path: 读取geojson文件
    :return: 封装dict  {layerName:[id:{要素dict}]}
    """
    geo_properties_map = {}
    if not geojson_path.endswith('.geojson'):
        return geo_properties_map

    with open(geojson_path, 'r') as file:
        content = file.read()
        geojson_data = json.loads(content)
        features = geojson_data.get('features', [])
        dict = {}
        for feature in features:
            properties = feature.get('properties')
            if GEOMETRY in feature:
                properties[GEOMETRY] = feature.get('geometry')
            dict[properties["id"]] = properties
        key = os.path.basename(geojson_path).replace(".geojson", "")
        geo_properties_map[key] = dict

    return geo_properties_map


def build_geojson(src_feats, layer_name='', epsg_crs=None):
    """按照图层,格式化成geojson规格"""
    attrs = []
    for attr in [attr for key, attr in src_feats.items()]:
        geos_obj = attr.get(GEOMETRY)
        gjson_dict = {"properties": attr, "type": "Feature"}
        if geos_obj is not None:
            gjson_dict[GEOMETRY] = geos_obj
            del attr[GEOMETRY]
        attrs.append(gjson_dict)

    layer = {"type": "FeatureCollection", "features": attrs}

    if layer_name:
        layer['name'] = layer_name
        if epsg_crs and src_feats and any(GEOMETRY in a for a in attrs):
            if isinstance(epsg_crs, int) or (isinstance(epsg_crs, str) and epsg_crs.isdigit()):
                crs_str = "urn:ogc:def:crs:EPSG::%s" % epsg_crs
            else:
                crs_str = epsg_crs
            layer['crs'] = {"type": "name", "properties": {"name": crs_str}}

    return layer


def write_layer(target_path, layer_name, node_data):
    '''
    按图层写geojson数据到磁盘
    :param target_path: 目标文件目录
    :param layer_name: 目标文件名
    :param node_data: 写入的dict嵌套类型数据{dict:{[id:value]}}
    :return:
    '''
    if not os.path.exists(target_path):
        os.makedirs(target_path)
    with open(target_path + "/" + layer_name + ".geojson", 'w') as f:
        json.dump(node_data, f)

    print(target_path + "/" + layer_name + ".geojson 写入完毕")

4.2 db_sq3_tool.py :处理sq3数据库

import sqlite3
import os
from shapely.geometry import shape
from read_file import rw_data_geojson
import random
import time
import multiprocessing
import datetime


def create_connection(db_file):
    """ 创建与SQLite数据库的连接 """
    conn = None
    try:
        conn = sqlite3.connect(db_file)
        return conn
    except sqlite3.Error as e:
        print(e)
    return conn


def create_table(conn, create_table_sql):
    """ 使用给定的SQL语句创建表 """
    try:
        cursor = conn.cursor()
        cursor.execute(create_table_sql)
        conn.commit()
    except sqlite3.Error as e:
        print(e)


def insert_data(conn, insert_sql, data):
    """ 向数据库插入数据 """
    try:
        cursor = conn.cursor()
        cursor.execute(insert_sql, data)
        conn.commit()
    except sqlite3.Error as e:
        print(e)


def batch_insert_data(conn, data_list, table_name, columns):
    '''
    批量插入数据
    :param conn: 数据库连接
    :param data_list: 插入数据list
    :param table_name: 表名
    :param columns: 表的列名list
    :return:
    '''
    cursor = conn.cursor()
    # 构建插入语句的占位符
    placeholders = ', '.join(['?'] * len(columns))
    insert_sql = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({placeholders})"

    try:
        cursor.executemany(insert_sql, data_list)
        conn.commit()
    except sqlite3.Error as e:
        print(f"An error occurred: {e}")
        conn.rollback()


def select_data(conn, select_sql):
    """ 从数据库查询数据 """
    try:
        cursor = conn.cursor()
        cursor.execute(select_sql)
        rows = cursor.fetchall()
        return rows
    except sqlite3.Error as e:
        print(e)


def update_data(conn, update_sql, data):
    """ 更新数据库中的数据 """
    try:
        cursor = conn.cursor()
        cursor.execute(update_sql, data)
        conn.commit()
    except sqlite3.Error as e:
        print(e)


def delete_data(conn, delete_sql, data):
    """ 从数据库中删除数据 """
    try:
        cursor = conn.cursor()
        cursor.execute(delete_sql, data)
        conn.commit()
    except sqlite3.Error as e:
        print(e)


def dict_data_write_sqlite(node_data, table_name, conn, batch=1):
    '''
    将读完的dict 结构图层内容写入到sq3
    :param node_data: 要写入的数据dict
    :param table_name: 表名称
    :param conn: 数据库连接
    :param batch: 是否批量插入
    :return:
    '''
    try:
        if len(node_data) == 0:
            return
        # 获取第一行的key转保存表的列名
        random_key, random_value = random.choice(list(node_data.items()))
        row0 = random_value
        flag2type = {'str': 'TEXT', 'int': 'BIGINT', 'float': 'REAL', 'dict': 'TEXT'}
        fld_types = []
        columns = []
        for key, value in row0.items():
            value_type = type(value).__name__
            # print(f'{table_name} key:{key} 类型: {value_type}')
            fld_types.append((key, flag2type[value_type]))
            columns.append(key)

        fld_sql = ','.join(f'{fld} {typ}' for fld, typ in fld_types if fld != 'id')
        pk_sql = 'id BIGINT PRIMARY KEY'
        create_tab_sql = f'CREATE TABLE IF NOT EXISTS {table_name} ({pk_sql}, {fld_sql});'

        if conn is not None:
            # 1.创建表结构
            create_table(conn, create_tab_sql)

            if batch:
                # 方式1:批量插入,一次提交,效率高
                data_list = []
                for id, data in node_data.items():
                    feature_list = []
                    for key, value in data.items():
                        if 'geometry' == key:
                            geometry = shape(value)
                            feature_list.append(str(geometry.wkt))
                        else:
                            feature_list.append(str(value))
                    data_list.append(feature_list)

                batch_insert_data(conn, data_list, table_name, columns)
            else:
                # 方式2:一条一条插入,适合小数据,效率低下
                for id, data in node_data.items():
                    # 插入数据的SQL语句和数据
                    cur_values = []
                    for key, value in data.items():
                        if 'geometry' == key:
                            geometry = shape(value)
                            cur_values.append("'" + str(geometry.wkt) + "'")
                        else:
                            cur_values.append(str(value))
                    flds_str = ','.join(columns)
                    vals_str = ','.join(cur_values)
                    insert_sql = f"insert into {table_name} ({flds_str}) values ({vals_str})"
                    insert_data(conn, insert_sql, data)

        print(f"{table_name} sq3写入成功")

    except Exception as e:
        print(" 写入sq3异常: " + e)


def read_single_geojson_write_sq3(args):
    '''
    单文件读写sq3
    :param args:
    :return:
    '''
    file_name, target_path = args
    layer_name = file_name.replace(".geojson", "")
    db_file = target_path + '/' + layer_name + '.sq3'
    # 创建数据库连接
    conn = create_connection(db_file)
    # 读取数据,写数据库
    geojson_file = os.path.join(folder_path, file_name)
    node_data = rw_data_geojson.read_single_layer(geojson_file)
    if len(node_data[layer_name]) > 0:
        for layer_name, layer_value in node_data.items():
            dict_data_write_sqlite(layer_value, layer_name, conn, batch=1)
    else:
        # 删除空sq3
        os.remove(db_file)
    conn.close()
    return file_name


def merge_sq3(target_path):
    # 连接到目标数据库(要拷贝到的数据库)
    localtime = time.localtime()
    merge_folder = target_path + "/" + "merge_sq3_finish"
    if not os.path.exists(merge_folder):
        os.makedirs(merge_folder)
    target_db = merge_folder + "/" + str(time.strftime('%Y%m%d', localtime)) + ".sq3"
    if os.path.exists(target_db):
        os.remove(target_db)
        print(f"{target_db} 已被删除。")
    target_conn = create_connection(target_db)
    target_cursor = target_conn.cursor()
    # 连接到源数据库(要拷贝的数据库)
    for item in os.listdir(target_path):
        table_name = item.replace(".sq3", "")
        source_db_file = os.path.join(target_path, item)
        if os.path.isfile(source_db_file) and item != '.DS_Store':
            # 附加源数据库到目标数据库连接
            target_cursor.execute(f"ATTACH DATABASE '{source_db_file}' AS source_db;")
            # 将源sq3中的 table_name 表 复制到 目标.sq3
            target_cursor.execute(f"CREATE TABLE {table_name} AS SELECT * FROM source_db.{table_name}")
            # 分离附加的数据库
            target_cursor.execute("DETACH DATABASE source_db;")
            target_conn.commit()

    # 提交更改并关闭连接
    target_conn.close()

5.单线程读写代码

    folder_path = '/Users/admin/Desktop/123/sq3效率/geojson'
    target_path = "/Users/admin/Desktop/123/sq3效率/merge_sq3"

    # 1.单线程全量读写
    start_time = time.time()
    node_data = rw_data_geojson.read_all_layer(folder_path)

    # 创建数据库连接
    db_file = target_path + '/' + '20240928.sq3'
    if os.path.exists(db_file):
        os.remove(db_file)
        print(f"{db_file} 已被删除。")
    conn = create_connection(db_file)
    for layer_name, layer_value in node_data.items():
        if len(node_data[layer_name]) > 0:
            dict_data_write_sqlite(layer_value, layer_name, conn, batch=0)
    end_time = time.time()
    execution_time = end_time - start_time
    print(f"写入sq3 函数执行时间:{execution_time} 秒")
    exit()

6.多线程读写,合并到一个sq3数据库

# 2.多文件多线程读写
    start_time = time.time()
    for root, dirs, files in os.walk(target_path):
        for file in files:
            db_file = os.path.join(root, file)
            os.remove(db_file)
            print(f"{db_file} 已被删除。")
    with multiprocessing.Pool(processes=5) as pool:
        for file_name in os.listdir(folder_path):
            if file_name == '.DS_Store':
                continue
            params = [(file_name, target_path)]
            pool.map(read_single_geojson_write_sq3, params)

    # 合并多个sq3文件
    merge_sq3(target_path)
    end_time = time.time()
    execution_time = end_time - start_time
    print(f"写入sq3 函数执行时间:{execution_time} 秒")
    exit()

6.在上述基础上,再继续提效

        若单个geojson文件太大时,可多线程分批读取,将读取的块内容,写到一个分块的.sq3,再并发合并到单个图层的sq3,最后将多个图层合并到一个sq3中。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值