插件-数据库

由于dataset是只支持同步操作,在性能方面有很大瓶颈,这里提供一个异步协程的插件。

pip install aiomysql
pip install nest_asyncio
import asyncio
import datetime
from types import NoneType

import aiomysql
import logging

import nest_asyncio

# 配置日志记录
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# 定义类型映射(这个类型映射,是为了防止在进行操作的时候,没有建表,用于进行自动建表)
type_mapping = {
    int: 'INT',
    float: 'FLOAT',
    # str: 'VARCHAR(255)',
    str: 'TEXT',
    bool: 'BOOLEAN',
    datetime.datetime: 'DATETIME',
    datetime.date: 'DATE',
    None: 'TEXT',
    NoneType: 'TEXT'
}


async def create_pool():
    """
    创建数据库连接池
    """
    pool = await aiomysql.create_pool(
        host='localhost',
        port=3306,
        user='root',
        password='123456',
        db='zhihu',
        autocommit=True,  # 自动提交事务
        minsize=1,  # 连接池最小连接数
        maxsize=10  # 连接池最大连接数
    )
    return pool


async def close_pool(pool):
    """
    关闭连接池
    :param pool: 数据库连接池
    """
    pool.close()
    await pool.wait_closed()


async def exists_table(pool, table_name: str):
    """
    检测表是否存在
    :param pool: 数据库连接池
    :param table_name: 数据库表名
    :return: 表是否存在,True 表示存在,False 表示不存在
    """
    # 构建 SQL 检测表是否存在语句
    sql = f'''
        SELECT COUNT(*)
        FROM information_schema.tables
        WHERE table_schema = DATABASE()
        AND table_name = %s
    '''
    logging.info(sql % (table_name,))
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            # 执行 SQL 检测表是否存在语句
            await cur.execute(sql, (table_name,))
            # 获取查询结果
            result = await cur.fetchone()
            return result[0] > 0  # 返回表是否存在


async def create_table(pool, table_name: str, data: dict):
    """
    创建表
    :param pool: 数据库连接池
    :param table_name: 数据库表名
    :param data: 数据字典,键为字段名,值为字段类型
    """

    # 自动添加主键 id
    columns = ['id INT AUTO_INCREMENT PRIMARY KEY']

    # 根据 data 字典中的数据类型生成字段定义
    if data:
        for key, value in data.items():
            field_type = type_mapping.get(type(value))
            if field_type:
                columns.append(f"`{key}` {field_type}")
            else:
                raise ValueError(f"Unsupported data type for field '{key}': {type(value)}")
    # 自动添加创建时间和更新时间字段
    columns.append('created_at DATETIME DEFAULT CURRENT_TIMESTAMP')
    columns.append('updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP')

    # 构建 SQL 创建表语句
    sql = f'''
            CREATE TABLE `{table_name}` (
                {', '.join(columns)}
            )
        '''
    logging.info(sql)
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            # 执行 SQL 创建表语句
            await cur.execute(sql)


async def drop_table(pool, table_name: str):
    """
    删除表
    :param pool: 数据库连接池
    :param table_name: 数据库表名
    """
    # 构建 SQL 删除表语句
    sql = f'''
        DROP TABLE IF EXISTS `{table_name}`
    '''
    logging.info(sql)
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            # 执行 SQL 删除表语句
            await cur.execute(sql)


async def truncate_table(pool, table_name: str):
    """
    截断表
    :param pool: 数据库连接池
    :param table_name: 数据库表名
    """
    # 构建 SQL 截断表语句
    sql = f'''
        TRUNCATE TABLE `{table_name}`
    '''
    logging.info(sql)
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            # 执行 SQL 截断表语句
            await cur.execute(sql)


async def update_table(pool, table_name: str, data: dict):
    """
    更新表结构
    :param pool: 数据库连接池
    :param table_name: 数据库表名
    :param data: 数据字典,键为字段名,值为字段类型
    """
    # 保留的字段
    reserved_fields = {'id', 'created_at', 'updated_at'}
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            # 获取当前表的字段信息
            await cur.execute(f"DESCRIBE `{table_name}`")
            current_fields = await cur.fetchall()
            current_fields_dict = {field[0]: field[1] for field in current_fields}

            # 构建 ALTER TABLE 语句
            alter_statements = []
            if data:
                for key, value in data.items():
                    field_type = type_mapping.get(type(value))
                    if field_type:
                        if key not in current_fields_dict and key not in reserved_fields:
                            alter_statements.append(f"ADD COLUMN `{key}` {field_type}")
                        elif key in current_fields_dict and current_fields_dict[
                            key] != field_type and key not in reserved_fields:
                            alter_statements.append(f"MODIFY COLUMN `{key}` {field_type}")
                    else:
                        continue
                        # raise ValueError(f"Unsupported data type for field '{key}': {type(value)}")

            if alter_statements:
                # 构建 SQL 更新表语句
                sql = f'''
                    ALTER TABLE `{table_name}`
                    {', '.join(alter_statements)}
                '''
                # 执行 SQL 更新表语句
                logging.info(sql)
                await cur.execute(sql)


async def insert(pool, table_name: str, data: dict):
    """
    插入数据
    :param pool 数据库连接池
    :param table_name 数据库表名
    :param data 数据
    """
    # 表是否存在
    exist = await exists_table(pool=pool, table_name=table_name)
    if not exist:
        # 自动建表
        await create_table(pool=pool, table_name=table_name, data=data)
    # 自动更新表
    await update_table(pool=pool, table_name=table_name, data=data)

    # 从 data 字典中提取字段名和对应的值
    columns = ', '.join([f'`{k}`' for k in data.keys()])
    placeholders = ', '.join(['%s'] * len(data))
    values = tuple(data.values())

    # 构建 SQL 插入语句
    sql = f'''
            INSERT INTO `{table_name}` ({columns})
            VALUES ({placeholders})
        '''
    logging.info(sql % values)
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            # 执行 SQL 插入语句
            await cur.execute(sql, values)


async def insert_batch(pool, table_name: str, data_list: list):
    """
    批量插入数据
    :param pool: 数据库连接池
    :param table_name: 数据库表名
    :param data_list: 数据列表,每个元素为一个字典,表示一条记录
    """
    if not data_list:
        return  # 如果没有数据,直接返回
    # 表是否存在
    exist = await exists_table(pool=pool, table_name=table_name)
    if not exist:
        # 自动建表
        await create_table(pool=pool, table_name=table_name, data=data_list[0])
    # 自动更新表
    await update_table(pool=pool, table_name=table_name, data=data_list[0])

    # 获取第一个字典的键作为字段名
    columns = list(data_list[0].keys())
    columns_str = ', '.join([f"`{c}`" for c in columns])

    # 构建插入语句的占位符
    placeholders = ', '.join(['%s'] * len(columns))

    # 构建 SQL 插入语句
    sql = f'''
        INSERT INTO `{table_name}` ({columns_str})
        VALUES ({placeholders})
    '''

    # 提取所有记录的值
    values = [tuple(data.values()) for data in data_list]
    logging.info(sql)
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            # 执行批量插入
            await cur.executemany(sql, values)


async def delete(pool, table_name: str, condition: dict):
    """
    条件删除
    :param pool: 数据库连接池
    :param table_name: 数据库表名
    :param condition: 条件字典,键为字段名,值为条件值
    """
    # 构建 WHERE 子句
    where_clause = ' AND '.join([f"`{key}` = %s" for key in condition.keys()])
    where_values = tuple(condition.values())

    # 构建 SQL 删除语句
    sql = f'''
        DELETE FROM `{table_name}`
        WHERE {where_clause}
    '''
    logging.info(sql % (where_values))
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            # 执行 SQL 删除语句
            await cur.execute(sql, where_values)


async def delete_by_ids(pool, table_name: str, id_list: list):
    """
    批量删除
    :param pool: 数据库连接池
    :param table_name: 数据库表名
    :param id_list: id 列表,每个元素为一个 id
    """
    if not id_list:
        return  # 如果没有数据,直接返回

    # 构建 SQL 删除语句
    sql = f'''
        DELETE FROM `{table_name}`
        WHERE id IN ({', '.join(['%s'] * len(id_list))})
    '''
    logging.info(sql)
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            # 执行批量删除
            await cur.execute(sql, id_list)


async def update(pool, table_name: str, data: dict, condition: dict):
    """
    更新数据
    :param pool: 数据库连接池
    :param table_name: 数据库表名
    :param data: 数据字典,键为字段名,值为要更新的数据
    :param condition: 条件字典,键为字段名,值为条件值
    """
    # 表是否存在
    exist = await exists_table(pool=pool, table_name=table_name)
    if not exist:
        # 自动建表
        await create_table(pool=pool, table_name=table_name, data=data)
    # 自动更新表
    await update_table(pool=pool, table_name=table_name, data=data)

    # 构建 SET 子句
    set_clause = ', '.join([f"`{key}` = %s" for key in data.keys()])
    set_values = tuple(data.values())

    # 构建 WHERE 子句
    where_clause = ' AND '.join([f"`{key}` = %s" for key in condition.keys()])
    where_values = tuple(condition.values())

    # 构建 SQL 更新语句
    sql = f'''
        UPDATE `{table_name}`
        SET {set_clause}
        WHERE {where_clause}
    '''

    # 合并 set_values 和 where_values
    values = set_values + where_values
    logging.info(sql % values)
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            # 执行 SQL 更新语句
            await cur.execute(sql, values)


async def update_by_id(pool, table_name: str, data: dict):
    """
    根据 id 更新数据
    :param pool: 数据库连接池
    :param table_name: 数据库表名
    :param data: 数据字典,键为字段名,值为要更新的数据
    """
    # 表是否存在
    exist = await exists_table(pool=pool, table_name=table_name)
    if not exist:
        # 自动建表
        await create_table(pool=pool, table_name=table_name, data=data)
    # 自动更新表
    await update_table(pool=pool, table_name=table_name, data=data)

    # 构建 SET 子句
    set_clause = ', '.join([f"`{key}` = %s" for key in data.keys()])
    set_values = tuple(data.values())

    # 构建 SQL 更新语句
    sql = f'''
        UPDATE `{table_name}`
        SET {set_clause}
        WHERE id = %s
    '''

    # 合并 set_values 和 id
    values = set_values + (data['id'],)
    logging.info(sql % values)
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            # 执行 SQL 更新语句
            await cur.execute(sql, values)


async def update_batch(pool, table_name: str, data_list: list):
    """
    批量更新数据
    :param pool: 数据库连接池
    :param table_name: 数据库表名
    :param data_list: 数据列表,每个元素为一个字典,表示一条记录,必须包含 'id' 字段
    """
    if not data_list:
        return  # 如果没有数据,直接返回

    # 表是否存在
    exist = await exists_table(pool=pool, table_name=table_name)
    if not exist:
        # 自动建表
        await create_table(pool=pool, table_name=table_name, data=data_list[0])
    # 自动更新表
    await update_table(pool=pool, table_name=table_name, data=data_list[0])

    # 获取第一个字典的键作为字段名(排除 'id')
    columns = list(data_list[0].keys())
    columns.remove('id')

    # 构建更新语句的 SET 子句
    set_clause = ', '.join([f"`{column}` = %s" for column in columns])

    # 构建 SQL 更新语句
    sql = f'''
        UPDATE `{table_name}`
        SET {set_clause}
        WHERE id = %s
    '''

    # 提取所有记录的值
    values = []
    for data in data_list:
        # 提取字段值(排除 'id')
        field_values = [data[column] for column in columns]
        # 添加 'id' 值
        field_values.append(data['id'])
        # 将值列表转换为元组
        values.append(tuple(field_values))
    logging.info(sql)
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            # 执行批量更新
            await cur.executemany(sql, values)


async def select_list(pool, table_name: str, condition: dict = None, columns: list = None, order_by: list = None,
                      order_direction: list = None):
    """
    查询数据
    :param pool: 数据库连接池
    :param table_name: 数据库表名
    :param condition: 条件字典,键为字段名,值为条件值
    :param columns: 要查询的列名列表,默认为所有列
    :return: 查询结果列表
    """
    # 表是否存在
    exist = await exists_table(pool=pool, table_name=table_name)
    if not exist:
        # 自动建表
        await create_table(pool=pool, table_name=table_name, data=condition)
    # 自动更新表
    await asyncio.shield(update_table(pool=pool, table_name=table_name, data=condition))

    # 默认查询所有列
    if columns is None:
        columns = ['*']
        # 构建 SELECT 子句
        select_clause = ', '.join(columns)
    else:
        # 构建 SELECT 子句
        select_clause = ', '.join(f"`{c}`" for c in columns)

    # 构建 WHERE 子句
    where_clause = ''
    where_values = ()
    if condition:
        clause = []
        for key in condition.keys():
            # 如果是字典
            if isinstance(condition[key], dict):
                clause.append(f"`{key}` {list(condition[key].keys())[0]} %s")
            else:
                clause.append(f"`{key}` = %s")
        where_clause = ' AND '.join(clause)
        where_values = tuple()
        for value in condition.values():
            # 如果是字典
            if isinstance(value, dict):
                where_values += tuple(value.values())
            else:
                where_values += (value,)
    # 构建 ORDER BY 子句
    order_by_clause = ''
    if order_by and order_direction:
        if len(order_by) != len(order_direction):
            raise ValueError("order_by 和 order_direction 列表长度必须相同")
        order_by_clause = ', '.join([f"{col} {dir}" for col, dir in zip(order_by, order_direction)])

    # 构建 SQL 查询语句
    sql = f'''
        SELECT {select_clause}
        FROM `{table_name}`
    '''
    if where_clause:
        sql += f' WHERE {where_clause}'
    if order_by_clause:
        sql += f' ORDER BY {order_by_clause}'

    logging.info(sql % where_values)
    async with pool.acquire() as conn:
        async with conn.cursor(aiomysql.DictCursor) as cur:
            # 执行 SQL 查询语句
            await cur.execute(sql, where_values)
            # 获取查询结果
            result = await cur.fetchall()
            return result
async def select_last_one(pool, table_name: str, condition: dict = None):
    """
    查询最新一条数据
    :param pool: 数据库连接池
    :param table_name: 数据库表名
    :param condition: 条件字典,键为字段名,值为条件值
    :param columns: 要查询的列名列表,默认为所有列
    :return: 查询结果列表
    """
    # 表是否存在
    exist = await exists_table(pool=pool, table_name=table_name)
    if not exist:
        # 自动建表
        await create_table(pool=pool, table_name=table_name, data=condition)
    # 自动更新表
    await asyncio.shield(update_table(pool=pool, table_name=table_name, data=condition))

    # 查询所有列
    columns = ['*']
    # 构建 SELECT 子句
    select_clause = ', '.join(columns)

    # 构建 WHERE 子句
    where_clause = ''
    where_values = ()
    if condition:
        clause = []
        for key in condition.keys():
            # 如果是字典
            if isinstance(condition[key], dict):
                clause.append(f"`{key}` {list(condition[key].keys())[0]} %s")
            else:
                clause.append(f"`{key}` = %s")
        where_clause = ' AND '.join(clause)
        where_values = tuple()
        for value in condition.values():
            # 如果是字典
            if isinstance(value, dict):
                where_values += tuple(value.values())
            else:
                where_values += (value,)
    # 构建 SQL 查询语句
    sql = f'''
        SELECT {select_clause}
        FROM `{table_name}`
    '''
    if where_clause:
        sql += f' WHERE {where_clause}'
    sql += f' ORDER BY id desc limit 1'

    logging.info(sql % where_values)
    async with pool.acquire() as conn:
        async with conn.cursor(aiomysql.DictCursor) as cur:
            # 执行 SQL 查询语句
            await cur.execute(sql, where_values)
            # 获取查询结果
            result = await cur.fetchone()
            return result

async def select_page(pool, table_name: str, condition: dict = None, columns: list = None, page: int = 1,
                      page_size: int = 10, order_by: list = None, order_direction: list = None):
    """
    分页查询数据
    :param pool: 数据库连接池
    :param table_name: 数据库表名
    :param condition: 条件字典,键为字段名,值为条件值
    :param columns: 要查询的列名列表,默认为所有列
    :param page: 当前页码,默认为1
    :param page_size: 每页记录数,默认为10
    :return: 查询结果列表
    """
    # 表是否存在
    exist = await exists_table(pool=pool, table_name=table_name)
    if not exist:
        # 自动建表
        await create_table(pool=pool, table_name=table_name, data=condition)
    # 自动更新表
    await asyncio.shield(update_table(pool=pool, table_name=table_name, data=condition))

    # 默认查询所有列
    if columns is None:
        columns = ['*']
        # 构建 SELECT 子句
        select_clause = ', '.join(columns)
    else:
        # 构建 SELECT 子句
        select_clause = ', '.join(f"`{c}`" for c in columns)

    # 构建 WHERE 子句
    where_clause = ''
    where_values = ()
    if condition:
        clause = []
        for key in condition.keys():
            # 如果是字典
            if isinstance(condition[key], dict):
                clause.append(f"`{key}` {list(condition[key].keys())[0]} %s")
            else:
                clause.append(f"`{key}` = %s")
        where_clause = ' AND '.join(clause)
        where_values = tuple()
        for value in condition.values():
            # 如果是字典
            if isinstance(value, dict):
                where_values += tuple(value.values())
            else:
                where_values += (value,)
    # 计算偏移量
    offset = (page - 1) * page_size

    # 构建 ORDER BY 子句
    order_by_clause = ''
    if order_by and order_direction:
        if len(order_by) != len(order_direction):
            raise ValueError("order_by 和 order_direction 列表长度必须相同")
        order_by_clause = ', '.join([f"{col} {dir}" for col, dir in zip(order_by, order_direction)])

    # 构建 SQL 查询语句
    sql = f'''
        SELECT {select_clause}
        FROM `{table_name}`
    '''
    if where_clause:
        sql += f' WHERE {where_clause}'
    if order_by_clause:
        sql += f' ORDER BY {order_by_clause}'

    sql += f' LIMIT {page_size} OFFSET {offset}'
    logging.info(sql % where_values)
    async with pool.acquire() as conn:
        async with conn.cursor(aiomysql.DictCursor) as cur:
            # 执行 SQL 查询语句
            await cur.execute(sql, where_values)
            # 获取查询结果
            result = await cur.fetchall()
            return result


async def select_by_id(pool, table_name: str, id: int):
    """
    根据 id 查询数据
    :param pool: 数据库连接池
    :param table_name: 数据库表名
    :param id: 要查询的记录的 id
    :return: 查询结果字典,如果没有找到则返回 None
    """
    # 表是否存在
    exist = await exists_table(pool=pool, table_name=table_name)
    if not exist:
        # 自动建表
        await create_table(pool=pool, table_name=table_name, data={})
    # 构建 SQL 查询语句
    sql = f'''
        SELECT *
        FROM `{table_name}`
        WHERE id = %s
    '''
    logging.info(sql % id)
    async with pool.acquire() as conn:
        async with conn.cursor(aiomysql.DictCursor) as cur:
            # 执行 SQL 查询语句
            await cur.execute(sql, (id,))
            # 获取查询结果
            result = await cur.fetchone()
            return result


async def select_count(pool, table_name: str, condition: dict = None):
    """
    查询记录数
    :param pool: 数据库连接池
    :param table_name: 数据库表名
    :param condition: 条件字典,键为字段名,值为条件值
    :return: 记录数
    """
    # 表是否存在
    exist = await exists_table(pool=pool, table_name=table_name)
    if not exist:
        # 自动建表
        await create_table(pool=pool, table_name=table_name, data=condition)
    # 自动更新表
    await update_table(pool=pool, table_name=table_name, data=condition)

    # 构建 WHERE 子句
    where_clause = ''
    where_values = ()
    if condition:
        clause = []
        for key in condition.keys():
            # 如果是字典
            if isinstance(condition[key], dict):
                clause.append(f"`{key}` {list(condition[key].keys())[0]} %s")
            else:
                clause.append(f"`{key}` = %s")
        where_clause = ' AND '.join(clause)
        where_values = tuple()
        for value in condition.values():
            # 如果是字典
            if isinstance(value, dict):
                where_values += tuple(value.values())
            else:
                where_values += (value,)

    # 构建 SQL 查询记录数语句
    sql = f'''
        SELECT COUNT(*)
        FROM `{table_name}`
    '''
    if where_clause:
        sql += f' WHERE {where_clause}'
    logging.info(sql % where_values)
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            # 执行 SQL 查询记录数语句
            await cur.execute(sql, where_values)
            # 获取查询结果
            result = await cur.fetchone()
            return result[0]  # 返回记录数


async def select_by_sql(pool, sql: str):
    """
    根据原生SQL查询数据
    :param pool: 数据库连接池
    :param sql: 原生SQL查询语句
    :param params: SQL查询参数,用于替换SQL中的占位符
    :return: 查询结果列表
    """

    # 记录SQL语句和参数
    logging.info(f"Executing SQL: {sql}")

    async with pool.acquire() as conn:
        async with conn.cursor(aiomysql.DictCursor) as cur:
            # 执行 SQL 查询语句
            await cur.execute(sql)
            # 获取查询结果
            result = await cur.fetchall()
            return result if result else None


async def main():
    init_data = {
        'username': "张三丰",
        'password': "太极",
        'email': "zhangsan@example.com",
        'age': 30,
        'status': True,
        'birthday': datetime.date.today(),
        'address': "武当山",
        'start': 0,
        'select': 0,
        'order': "真假"
    }

    pool = await create_pool()

    try:
        # 插入数据
        await insert(pool=pool, table_name='user', data=init_data)
        # 批量插入
        await insert_batch(pool=pool, table_name='user', data_list=[init_data])

        # 查询列表
        data_list: list[dict] = await select_list(pool=pool, table_name='user', condition=None, order_by=['id'],
                                                  order_direction=['DESC'])
        # ID查询
        data: dict = await select_by_id(pool=pool, table_name='user', id=1)

        # 并发更新数据
        await asyncio.gather(
            *[update(pool=pool, table_name='user', data=data, condition={"id": data['id']}) for data in data_list])

        # 批量更新
        await update_batch(pool=pool, table_name='user', data_list=data_list)

        # 查询数量
        count = await select_count(pool=pool, table_name='user', condition={"status": True})

        # 分页查询
        page_data = await select_page(pool=pool, table_name='user', condition={"status": True}, columns=None,
                                      page=1, page_size=10, order_by=['id'], order_direction=['DESC'])

        # 并发删除数据
        await asyncio.gather(
            *[delete(pool=pool, table_name='user', condition={"id": data['id']}) for data in data_list])

        # 批量删除
        await delete_by_ids(pool=pool, table_name='user', id_list=[data['id'] for data in data_list])

    finally:
        await close_pool(pool)


if __name__ == '__main__':
    nest_asyncio.apply()
    asyncio.run(main())

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

文子阳

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值