由于dataset是只支持同步操作,在性能方面有很大瓶颈,这里提供一个异步协程的插件。
pip install aiomysql
pip install nest_asyncio
import asyncio
import datetime
from types import NoneType
import aiomysql
import logging
import nest_asyncio
# 配置日志记录
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# 定义类型映射(这个类型映射,是为了防止在进行操作的时候,没有建表,用于进行自动建表)
type_mapping = {
int: 'INT',
float: 'FLOAT',
# str: 'VARCHAR(255)',
str: 'TEXT',
bool: 'BOOLEAN',
datetime.datetime: 'DATETIME',
datetime.date: 'DATE',
None: 'TEXT',
NoneType: 'TEXT'
}
async def create_pool():
"""
创建数据库连接池
"""
pool = await aiomysql.create_pool(
host='localhost',
port=3306,
user='root',
password='123456',
db='zhihu',
autocommit=True, # 自动提交事务
minsize=1, # 连接池最小连接数
maxsize=10 # 连接池最大连接数
)
return pool
async def close_pool(pool):
"""
关闭连接池
:param pool: 数据库连接池
"""
pool.close()
await pool.wait_closed()
async def exists_table(pool, table_name: str):
"""
检测表是否存在
:param pool: 数据库连接池
:param table_name: 数据库表名
:return: 表是否存在,True 表示存在,False 表示不存在
"""
# 构建 SQL 检测表是否存在语句
sql = f'''
SELECT COUNT(*)
FROM information_schema.tables
WHERE table_schema = DATABASE()
AND table_name = %s
'''
logging.info(sql % (table_name,))
async with pool.acquire() as conn:
async with conn.cursor() as cur:
# 执行 SQL 检测表是否存在语句
await cur.execute(sql, (table_name,))
# 获取查询结果
result = await cur.fetchone()
return result[0] > 0 # 返回表是否存在
async def create_table(pool, table_name: str, data: dict):
"""
创建表
:param pool: 数据库连接池
:param table_name: 数据库表名
:param data: 数据字典,键为字段名,值为字段类型
"""
# 自动添加主键 id
columns = ['id INT AUTO_INCREMENT PRIMARY KEY']
# 根据 data 字典中的数据类型生成字段定义
if data:
for key, value in data.items():
field_type = type_mapping.get(type(value))
if field_type:
columns.append(f"`{key}` {field_type}")
else:
raise ValueError(f"Unsupported data type for field '{key}': {type(value)}")
# 自动添加创建时间和更新时间字段
columns.append('created_at DATETIME DEFAULT CURRENT_TIMESTAMP')
columns.append('updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP')
# 构建 SQL 创建表语句
sql = f'''
CREATE TABLE `{table_name}` (
{', '.join(columns)}
)
'''
logging.info(sql)
async with pool.acquire() as conn:
async with conn.cursor() as cur:
# 执行 SQL 创建表语句
await cur.execute(sql)
async def drop_table(pool, table_name: str):
"""
删除表
:param pool: 数据库连接池
:param table_name: 数据库表名
"""
# 构建 SQL 删除表语句
sql = f'''
DROP TABLE IF EXISTS `{table_name}`
'''
logging.info(sql)
async with pool.acquire() as conn:
async with conn.cursor() as cur:
# 执行 SQL 删除表语句
await cur.execute(sql)
async def truncate_table(pool, table_name: str):
"""
截断表
:param pool: 数据库连接池
:param table_name: 数据库表名
"""
# 构建 SQL 截断表语句
sql = f'''
TRUNCATE TABLE `{table_name}`
'''
logging.info(sql)
async with pool.acquire() as conn:
async with conn.cursor() as cur:
# 执行 SQL 截断表语句
await cur.execute(sql)
async def update_table(pool, table_name: str, data: dict):
"""
更新表结构
:param pool: 数据库连接池
:param table_name: 数据库表名
:param data: 数据字典,键为字段名,值为字段类型
"""
# 保留的字段
reserved_fields = {'id', 'created_at', 'updated_at'}
async with pool.acquire() as conn:
async with conn.cursor() as cur:
# 获取当前表的字段信息
await cur.execute(f"DESCRIBE `{table_name}`")
current_fields = await cur.fetchall()
current_fields_dict = {field[0]: field[1] for field in current_fields}
# 构建 ALTER TABLE 语句
alter_statements = []
if data:
for key, value in data.items():
field_type = type_mapping.get(type(value))
if field_type:
if key not in current_fields_dict and key not in reserved_fields:
alter_statements.append(f"ADD COLUMN `{key}` {field_type}")
elif key in current_fields_dict and current_fields_dict[
key] != field_type and key not in reserved_fields:
alter_statements.append(f"MODIFY COLUMN `{key}` {field_type}")
else:
continue
# raise ValueError(f"Unsupported data type for field '{key}': {type(value)}")
if alter_statements:
# 构建 SQL 更新表语句
sql = f'''
ALTER TABLE `{table_name}`
{', '.join(alter_statements)}
'''
# 执行 SQL 更新表语句
logging.info(sql)
await cur.execute(sql)
async def insert(pool, table_name: str, data: dict):
"""
插入数据
:param pool 数据库连接池
:param table_name 数据库表名
:param data 数据
"""
# 表是否存在
exist = await exists_table(pool=pool, table_name=table_name)
if not exist:
# 自动建表
await create_table(pool=pool, table_name=table_name, data=data)
# 自动更新表
await update_table(pool=pool, table_name=table_name, data=data)
# 从 data 字典中提取字段名和对应的值
columns = ', '.join([f'`{k}`' for k in data.keys()])
placeholders = ', '.join(['%s'] * len(data))
values = tuple(data.values())
# 构建 SQL 插入语句
sql = f'''
INSERT INTO `{table_name}` ({columns})
VALUES ({placeholders})
'''
logging.info(sql % values)
async with pool.acquire() as conn:
async with conn.cursor() as cur:
# 执行 SQL 插入语句
await cur.execute(sql, values)
async def insert_batch(pool, table_name: str, data_list: list):
"""
批量插入数据
:param pool: 数据库连接池
:param table_name: 数据库表名
:param data_list: 数据列表,每个元素为一个字典,表示一条记录
"""
if not data_list:
return # 如果没有数据,直接返回
# 表是否存在
exist = await exists_table(pool=pool, table_name=table_name)
if not exist:
# 自动建表
await create_table(pool=pool, table_name=table_name, data=data_list[0])
# 自动更新表
await update_table(pool=pool, table_name=table_name, data=data_list[0])
# 获取第一个字典的键作为字段名
columns = list(data_list[0].keys())
columns_str = ', '.join([f"`{c}`" for c in columns])
# 构建插入语句的占位符
placeholders = ', '.join(['%s'] * len(columns))
# 构建 SQL 插入语句
sql = f'''
INSERT INTO `{table_name}` ({columns_str})
VALUES ({placeholders})
'''
# 提取所有记录的值
values = [tuple(data.values()) for data in data_list]
logging.info(sql)
async with pool.acquire() as conn:
async with conn.cursor() as cur:
# 执行批量插入
await cur.executemany(sql, values)
async def delete(pool, table_name: str, condition: dict):
"""
条件删除
:param pool: 数据库连接池
:param table_name: 数据库表名
:param condition: 条件字典,键为字段名,值为条件值
"""
# 构建 WHERE 子句
where_clause = ' AND '.join([f"`{key}` = %s" for key in condition.keys()])
where_values = tuple(condition.values())
# 构建 SQL 删除语句
sql = f'''
DELETE FROM `{table_name}`
WHERE {where_clause}
'''
logging.info(sql % (where_values))
async with pool.acquire() as conn:
async with conn.cursor() as cur:
# 执行 SQL 删除语句
await cur.execute(sql, where_values)
async def delete_by_ids(pool, table_name: str, id_list: list):
"""
批量删除
:param pool: 数据库连接池
:param table_name: 数据库表名
:param id_list: id 列表,每个元素为一个 id
"""
if not id_list:
return # 如果没有数据,直接返回
# 构建 SQL 删除语句
sql = f'''
DELETE FROM `{table_name}`
WHERE id IN ({', '.join(['%s'] * len(id_list))})
'''
logging.info(sql)
async with pool.acquire() as conn:
async with conn.cursor() as cur:
# 执行批量删除
await cur.execute(sql, id_list)
async def update(pool, table_name: str, data: dict, condition: dict):
"""
更新数据
:param pool: 数据库连接池
:param table_name: 数据库表名
:param data: 数据字典,键为字段名,值为要更新的数据
:param condition: 条件字典,键为字段名,值为条件值
"""
# 表是否存在
exist = await exists_table(pool=pool, table_name=table_name)
if not exist:
# 自动建表
await create_table(pool=pool, table_name=table_name, data=data)
# 自动更新表
await update_table(pool=pool, table_name=table_name, data=data)
# 构建 SET 子句
set_clause = ', '.join([f"`{key}` = %s" for key in data.keys()])
set_values = tuple(data.values())
# 构建 WHERE 子句
where_clause = ' AND '.join([f"`{key}` = %s" for key in condition.keys()])
where_values = tuple(condition.values())
# 构建 SQL 更新语句
sql = f'''
UPDATE `{table_name}`
SET {set_clause}
WHERE {where_clause}
'''
# 合并 set_values 和 where_values
values = set_values + where_values
logging.info(sql % values)
async with pool.acquire() as conn:
async with conn.cursor() as cur:
# 执行 SQL 更新语句
await cur.execute(sql, values)
async def update_by_id(pool, table_name: str, data: dict):
"""
根据 id 更新数据
:param pool: 数据库连接池
:param table_name: 数据库表名
:param data: 数据字典,键为字段名,值为要更新的数据
"""
# 表是否存在
exist = await exists_table(pool=pool, table_name=table_name)
if not exist:
# 自动建表
await create_table(pool=pool, table_name=table_name, data=data)
# 自动更新表
await update_table(pool=pool, table_name=table_name, data=data)
# 构建 SET 子句
set_clause = ', '.join([f"`{key}` = %s" for key in data.keys()])
set_values = tuple(data.values())
# 构建 SQL 更新语句
sql = f'''
UPDATE `{table_name}`
SET {set_clause}
WHERE id = %s
'''
# 合并 set_values 和 id
values = set_values + (data['id'],)
logging.info(sql % values)
async with pool.acquire() as conn:
async with conn.cursor() as cur:
# 执行 SQL 更新语句
await cur.execute(sql, values)
async def update_batch(pool, table_name: str, data_list: list):
"""
批量更新数据
:param pool: 数据库连接池
:param table_name: 数据库表名
:param data_list: 数据列表,每个元素为一个字典,表示一条记录,必须包含 'id' 字段
"""
if not data_list:
return # 如果没有数据,直接返回
# 表是否存在
exist = await exists_table(pool=pool, table_name=table_name)
if not exist:
# 自动建表
await create_table(pool=pool, table_name=table_name, data=data_list[0])
# 自动更新表
await update_table(pool=pool, table_name=table_name, data=data_list[0])
# 获取第一个字典的键作为字段名(排除 'id')
columns = list(data_list[0].keys())
columns.remove('id')
# 构建更新语句的 SET 子句
set_clause = ', '.join([f"`{column}` = %s" for column in columns])
# 构建 SQL 更新语句
sql = f'''
UPDATE `{table_name}`
SET {set_clause}
WHERE id = %s
'''
# 提取所有记录的值
values = []
for data in data_list:
# 提取字段值(排除 'id')
field_values = [data[column] for column in columns]
# 添加 'id' 值
field_values.append(data['id'])
# 将值列表转换为元组
values.append(tuple(field_values))
logging.info(sql)
async with pool.acquire() as conn:
async with conn.cursor() as cur:
# 执行批量更新
await cur.executemany(sql, values)
async def select_list(pool, table_name: str, condition: dict = None, columns: list = None, order_by: list = None,
order_direction: list = None):
"""
查询数据
:param pool: 数据库连接池
:param table_name: 数据库表名
:param condition: 条件字典,键为字段名,值为条件值
:param columns: 要查询的列名列表,默认为所有列
:return: 查询结果列表
"""
# 表是否存在
exist = await exists_table(pool=pool, table_name=table_name)
if not exist:
# 自动建表
await create_table(pool=pool, table_name=table_name, data=condition)
# 自动更新表
await asyncio.shield(update_table(pool=pool, table_name=table_name, data=condition))
# 默认查询所有列
if columns is None:
columns = ['*']
# 构建 SELECT 子句
select_clause = ', '.join(columns)
else:
# 构建 SELECT 子句
select_clause = ', '.join(f"`{c}`" for c in columns)
# 构建 WHERE 子句
where_clause = ''
where_values = ()
if condition:
clause = []
for key in condition.keys():
# 如果是字典
if isinstance(condition[key], dict):
clause.append(f"`{key}` {list(condition[key].keys())[0]} %s")
else:
clause.append(f"`{key}` = %s")
where_clause = ' AND '.join(clause)
where_values = tuple()
for value in condition.values():
# 如果是字典
if isinstance(value, dict):
where_values += tuple(value.values())
else:
where_values += (value,)
# 构建 ORDER BY 子句
order_by_clause = ''
if order_by and order_direction:
if len(order_by) != len(order_direction):
raise ValueError("order_by 和 order_direction 列表长度必须相同")
order_by_clause = ', '.join([f"{col} {dir}" for col, dir in zip(order_by, order_direction)])
# 构建 SQL 查询语句
sql = f'''
SELECT {select_clause}
FROM `{table_name}`
'''
if where_clause:
sql += f' WHERE {where_clause}'
if order_by_clause:
sql += f' ORDER BY {order_by_clause}'
logging.info(sql % where_values)
async with pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cur:
# 执行 SQL 查询语句
await cur.execute(sql, where_values)
# 获取查询结果
result = await cur.fetchall()
return result
async def select_last_one(pool, table_name: str, condition: dict = None):
"""
查询最新一条数据
:param pool: 数据库连接池
:param table_name: 数据库表名
:param condition: 条件字典,键为字段名,值为条件值
:param columns: 要查询的列名列表,默认为所有列
:return: 查询结果列表
"""
# 表是否存在
exist = await exists_table(pool=pool, table_name=table_name)
if not exist:
# 自动建表
await create_table(pool=pool, table_name=table_name, data=condition)
# 自动更新表
await asyncio.shield(update_table(pool=pool, table_name=table_name, data=condition))
# 查询所有列
columns = ['*']
# 构建 SELECT 子句
select_clause = ', '.join(columns)
# 构建 WHERE 子句
where_clause = ''
where_values = ()
if condition:
clause = []
for key in condition.keys():
# 如果是字典
if isinstance(condition[key], dict):
clause.append(f"`{key}` {list(condition[key].keys())[0]} %s")
else:
clause.append(f"`{key}` = %s")
where_clause = ' AND '.join(clause)
where_values = tuple()
for value in condition.values():
# 如果是字典
if isinstance(value, dict):
where_values += tuple(value.values())
else:
where_values += (value,)
# 构建 SQL 查询语句
sql = f'''
SELECT {select_clause}
FROM `{table_name}`
'''
if where_clause:
sql += f' WHERE {where_clause}'
sql += f' ORDER BY id desc limit 1'
logging.info(sql % where_values)
async with pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cur:
# 执行 SQL 查询语句
await cur.execute(sql, where_values)
# 获取查询结果
result = await cur.fetchone()
return result
async def select_page(pool, table_name: str, condition: dict = None, columns: list = None, page: int = 1,
page_size: int = 10, order_by: list = None, order_direction: list = None):
"""
分页查询数据
:param pool: 数据库连接池
:param table_name: 数据库表名
:param condition: 条件字典,键为字段名,值为条件值
:param columns: 要查询的列名列表,默认为所有列
:param page: 当前页码,默认为1
:param page_size: 每页记录数,默认为10
:return: 查询结果列表
"""
# 表是否存在
exist = await exists_table(pool=pool, table_name=table_name)
if not exist:
# 自动建表
await create_table(pool=pool, table_name=table_name, data=condition)
# 自动更新表
await asyncio.shield(update_table(pool=pool, table_name=table_name, data=condition))
# 默认查询所有列
if columns is None:
columns = ['*']
# 构建 SELECT 子句
select_clause = ', '.join(columns)
else:
# 构建 SELECT 子句
select_clause = ', '.join(f"`{c}`" for c in columns)
# 构建 WHERE 子句
where_clause = ''
where_values = ()
if condition:
clause = []
for key in condition.keys():
# 如果是字典
if isinstance(condition[key], dict):
clause.append(f"`{key}` {list(condition[key].keys())[0]} %s")
else:
clause.append(f"`{key}` = %s")
where_clause = ' AND '.join(clause)
where_values = tuple()
for value in condition.values():
# 如果是字典
if isinstance(value, dict):
where_values += tuple(value.values())
else:
where_values += (value,)
# 计算偏移量
offset = (page - 1) * page_size
# 构建 ORDER BY 子句
order_by_clause = ''
if order_by and order_direction:
if len(order_by) != len(order_direction):
raise ValueError("order_by 和 order_direction 列表长度必须相同")
order_by_clause = ', '.join([f"{col} {dir}" for col, dir in zip(order_by, order_direction)])
# 构建 SQL 查询语句
sql = f'''
SELECT {select_clause}
FROM `{table_name}`
'''
if where_clause:
sql += f' WHERE {where_clause}'
if order_by_clause:
sql += f' ORDER BY {order_by_clause}'
sql += f' LIMIT {page_size} OFFSET {offset}'
logging.info(sql % where_values)
async with pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cur:
# 执行 SQL 查询语句
await cur.execute(sql, where_values)
# 获取查询结果
result = await cur.fetchall()
return result
async def select_by_id(pool, table_name: str, id: int):
"""
根据 id 查询数据
:param pool: 数据库连接池
:param table_name: 数据库表名
:param id: 要查询的记录的 id
:return: 查询结果字典,如果没有找到则返回 None
"""
# 表是否存在
exist = await exists_table(pool=pool, table_name=table_name)
if not exist:
# 自动建表
await create_table(pool=pool, table_name=table_name, data={})
# 构建 SQL 查询语句
sql = f'''
SELECT *
FROM `{table_name}`
WHERE id = %s
'''
logging.info(sql % id)
async with pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cur:
# 执行 SQL 查询语句
await cur.execute(sql, (id,))
# 获取查询结果
result = await cur.fetchone()
return result
async def select_count(pool, table_name: str, condition: dict = None):
"""
查询记录数
:param pool: 数据库连接池
:param table_name: 数据库表名
:param condition: 条件字典,键为字段名,值为条件值
:return: 记录数
"""
# 表是否存在
exist = await exists_table(pool=pool, table_name=table_name)
if not exist:
# 自动建表
await create_table(pool=pool, table_name=table_name, data=condition)
# 自动更新表
await update_table(pool=pool, table_name=table_name, data=condition)
# 构建 WHERE 子句
where_clause = ''
where_values = ()
if condition:
clause = []
for key in condition.keys():
# 如果是字典
if isinstance(condition[key], dict):
clause.append(f"`{key}` {list(condition[key].keys())[0]} %s")
else:
clause.append(f"`{key}` = %s")
where_clause = ' AND '.join(clause)
where_values = tuple()
for value in condition.values():
# 如果是字典
if isinstance(value, dict):
where_values += tuple(value.values())
else:
where_values += (value,)
# 构建 SQL 查询记录数语句
sql = f'''
SELECT COUNT(*)
FROM `{table_name}`
'''
if where_clause:
sql += f' WHERE {where_clause}'
logging.info(sql % where_values)
async with pool.acquire() as conn:
async with conn.cursor() as cur:
# 执行 SQL 查询记录数语句
await cur.execute(sql, where_values)
# 获取查询结果
result = await cur.fetchone()
return result[0] # 返回记录数
async def select_by_sql(pool, sql: str):
"""
根据原生SQL查询数据
:param pool: 数据库连接池
:param sql: 原生SQL查询语句
:param params: SQL查询参数,用于替换SQL中的占位符
:return: 查询结果列表
"""
# 记录SQL语句和参数
logging.info(f"Executing SQL: {sql}")
async with pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cur:
# 执行 SQL 查询语句
await cur.execute(sql)
# 获取查询结果
result = await cur.fetchall()
return result if result else None
async def main():
init_data = {
'username': "张三丰",
'password': "太极",
'email': "zhangsan@example.com",
'age': 30,
'status': True,
'birthday': datetime.date.today(),
'address': "武当山",
'start': 0,
'select': 0,
'order': "真假"
}
pool = await create_pool()
try:
# 插入数据
await insert(pool=pool, table_name='user', data=init_data)
# 批量插入
await insert_batch(pool=pool, table_name='user', data_list=[init_data])
# 查询列表
data_list: list[dict] = await select_list(pool=pool, table_name='user', condition=None, order_by=['id'],
order_direction=['DESC'])
# ID查询
data: dict = await select_by_id(pool=pool, table_name='user', id=1)
# 并发更新数据
await asyncio.gather(
*[update(pool=pool, table_name='user', data=data, condition={"id": data['id']}) for data in data_list])
# 批量更新
await update_batch(pool=pool, table_name='user', data_list=data_list)
# 查询数量
count = await select_count(pool=pool, table_name='user', condition={"status": True})
# 分页查询
page_data = await select_page(pool=pool, table_name='user', condition={"status": True}, columns=None,
page=1, page_size=10, order_by=['id'], order_direction=['DESC'])
# 并发删除数据
await asyncio.gather(
*[delete(pool=pool, table_name='user', condition={"id": data['id']}) for data in data_list])
# 批量删除
await delete_by_ids(pool=pool, table_name='user', id_list=[data['id'] for data in data_list])
finally:
await close_pool(pool)
if __name__ == '__main__':
nest_asyncio.apply()
asyncio.run(main())