〇、相关资料

1、SQLAlchemy document

 https://docs.sqlalchemy.org/en/20/

2、SQLAlchemy一般方法总结

 https://blog.csdn.net/chehec2010/article/details/122970638

3、FastAPI+SQLModel开发增删改查接口

 https://mp.weixin.qq.com/s/1Hl-UPiZ6ujVkEtF-_l00w


一、SQLAlchemy

1.1 实体类

1、数据库建表语句(PostgreSQL)
-- Recreate t_table_name from scratch; CASCADE also drops objects that depend on it.
drop table if exists t_table_name cascade;

-- NOTE(review): the ORM model (TableClass) also declares custom_id VARCHAR(20) UNIQUE NOT NULL,
-- which this table does not define — confirm which side is authoritative.
CREATE TABLE t_table_name (
    id BIGSERIAL PRIMARY KEY, -- BIGSERIAL auto-generates the ID; an ID that embeds the date/time would need extra handling
    name VARCHAR(255) NOT NULL,
    pt VARCHAR(255) NOT NULL,
    -- start/end dates are stored as 8-character strings, e.g. '20231101'
    start_time VARCHAR(8) DEFAULT '20231101',
    end_time VARCHAR(8) DEFAULT '20231130',
    is_deleted BOOLEAN DEFAULT FALSE,
    created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
    created_by VARCHAR(255) DEFAULT 'admin',
    updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
    updated_by VARCHAR(255),
    deleted_at TIMESTAMP WITH TIME ZONE,
    deleted_by VARCHAR(255)
);

-- Trigger function: sets NEW.updated_at to the current timestamp and returns the row.
-- It does not modify deleted_at.
CREATE OR REPLACE FUNCTION update_timestamps()
RETURNS TRIGGER AS $$
BEGIN
    NEW.updated_at = CURRENT_TIMESTAMP;
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;

-- Trigger: runs update_timestamps() BEFORE each row UPDATE (not on DELETE),
-- so updated_at is refreshed on every update.
-- NOTE: on PostgreSQL 11+ the preferred spelling is EXECUTE FUNCTION; EXECUTE PROCEDURE is the
-- older, deprecated form kept here for compatibility with older servers.
CREATE TRIGGER  update_t_table_name_timestamps
BEFORE UPDATE ON t_table_name
FOR EACH ROW EXECUTE PROCEDURE public.update_timestamps();
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.


2、实体类定义
import os
from sqlalchemy.orm import sessionmaker
from sqlalchemy import create_engine, select, update, delete, insert  
from sqlalchemy.orm import Session, declarative_base
# from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Integer, BigInteger, String, Text, DATE, Boolean, TIMESTAMP, create_engine  
from sqlalchemy.sql import select
from urllib.parse import quote_plus
from sqlalchemy import text  
from sqlalchemy.exc import NoResultFound  
from sqlalchemy import Column, TIMESTAMP, func
from util import DbCommonUtils

Base = declarative_base()


class TableClass(Base):
    """ORM mapping for the ``t_table_name`` table.

    Column types and defaults mirror the PostgreSQL DDL; timestamp columns
    are timezone-aware.
    """

    __tablename__ = 't_table_name'

    # Auto-increment primary key (maps to BIGSERIAL on PostgreSQL).
    id = Column(BigInteger, primary_key=True, autoincrement=True)
    # NOTE(review): declared NOT NULL + UNIQUE here, but the CREATE TABLE
    # statement for t_table_name has no custom_id column — confirm which one
    # is authoritative before relying on create_all / inserts.
    custom_id = Column(String(20), unique=True, nullable=False)
    name = Column(String(255), nullable=False)
    pt = Column(String(255), nullable=False)
    start_time = Column(String(8), default='20231101')
    end_time = Column(String(8), default='20231130')
    is_deleted = Column(Boolean, default=False)
    # server_default must be a SQL expression: a plain Python string would be
    # emitted as the quoted literal 'CURRENT_TIMESTAMP', not the SQL function.
    created_at = Column(TIMESTAMP(timezone=True), server_default=func.current_timestamp())
    created_by = Column(String(255), default='admin')
    # Refreshed by the ORM on every UPDATE it issues (onupdate); the DB trigger
    # covers updates made outside the ORM.
    updated_at = Column(
        TIMESTAMP(timezone=True),
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
    )
    updated_by = Column(String(255))
    deleted_at = Column(TIMESTAMP(timezone=True))
    deleted_by = Column(String(255))

    def __repr__(self):
        """Return a multi-line, human-readable dump of every mapped field."""
        fields = [
            f"ID: {self.id}",
            f"Custom ID: {self.custom_id}",
            f"Name: {self.name}",
            f"PT: {self.pt}",
            f"Start Time: {self.start_time}",
            f"End Time: {self.end_time}",
            f"Is Deleted: {self.is_deleted}",
            f"Created At: {self.created_at}",
            f"Created By: {self.created_by}",
            f"Updated At: {self.updated_at}",
            f"Updated By: {self.updated_by}",
            f"Deleted At: {self.deleted_at}",
            f"Deleted By: {self.deleted_by}",
        ]
        # One "Label: value" pair per line.
        return '\n'.join(fields)
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.
  • 31.
  • 32.
  • 33.
  • 34.
  • 35.
  • 36.
  • 37.
  • 38.
  • 39.
  • 40.
  • 41.
  • 42.
  • 43.
  • 44.
  • 45.
  • 46.
  • 47.
  • 48.
  • 49.

1.2 初始化连接

# NOTE(review): username/password/host/port/dbname must be defined before this
# line; if the password can contain URL-special characters (@, /, :), escape it
# (e.g. with quote_plus) before building the URL.
connection_string = f'postgresql+psycopg2://{username}:{password}@{host}:{port}/{dbname}'


class TableClassInvoke(object):
    """Owns the engine, an ORM session and a shared Connection for the CRUD helpers."""

    def __init__(self):
        # Set search_path on every pooled DBAPI connection via libpq "options".
        # A one-off "SET search_path" on a throwaway connection does not stick:
        # that connection's transaction is rolled back when it is released, and
        # PostgreSQL reverts a SET issued inside a rolled-back transaction.
        self.engine = create_engine(
            connection_string,
            connect_args={"options": "-csearch_path=public"},
        )
        session_factory = sessionmaker(bind=self.engine)
        self.session = session_factory()
        # Shared Connection used by the text()-based helpers (insert_batch,
        # update_geom_col_value), which call self.conn.execute/commit.
        self.conn = self.engine.connect()
        # Use the metadata the model is actually registered on; re-declaring a
        # fresh declarative Base here would give an empty metadata and
        # create_all would create no tables.
        TableClass.metadata.create_all(self.engine)

    def close(self):
        """Release the ORM session, the shared connection and the engine's pool."""
        self.session.close()
        self.conn.close()
        self.engine.dispose()
 
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.

1.3 增

1、单条插入
def add_new_info(self, name: str, pt: str, user_name: str, wkt: str, **kwargs):
    """
    Add a new city-task record to the database and commit it.

    :param name: task name, required (NOT NULL column)
    :param pt: partition date, required (NOT NULL column)
    :param user_name: value stored in ``created_by``
    :param wkt: value stored in the record's ``wkt`` attribute
    :param kwargs: any other column values accepted by ``TableClass``
        (e.g. ``custom_id``, ``start_time``); forwarded to the constructor
    :return: the newly persisted ``TableClass`` instance
    :raises Exception: any error raised by the commit is re-raised after the
        session has been rolled back
    """
    # NOTE(review): the TableClass model shown in this file declares no `wkt`
    # column, and the declarative constructor rejects unknown keywords —
    # confirm the real model defines it.
    new_record = TableClass(
        name=name,
        pt=pt,
        created_by=user_name,
        wkt=wkt,
        **kwargs,
    )

    self.session.add(new_record)
    try:
        self.session.commit()
    except Exception:
        # Roll back so the session stays usable for subsequent calls instead
        # of failing with a pending-rollback error.
        self.session.rollback()
        raise

    return new_record
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
2、批量插入
    def insert_batch(self, table_name, pt, city, records):
        """
        Rebuild the ``<table_name>_<city>_<pt>`` child table and bulk-insert ``records`` into it.

        The child table is dropped if present and recreated with
        ``INHERITS (table_name)`` plus a CHECK constraint on ``pt``/``city``, so
        re-running the load for the same partition replaces its contents.
        Drop, create and insert run in a single transaction on ``self.conn``:
        on any failure the transaction is rolled back, so a previously loaded
        partition is not lost.

        :param table_name: parent table name (interpolated into the DDL unquoted)
        :param pt: partition date value
        :param city: partition city code
        :param records: list of dicts mapping column name -> value; if empty,
            the partition is (re)created but nothing is inserted
        """
        from sqlalchemy import MetaData, Table
        from sqlalchemy.dialects.postgresql import insert as pg_insert

        # Plain str so it can be used both inside the DDL and as a Table name.
        partition_tb_name = f"{table_name}_{city}_{pt}"
        # NOTE(review): table_name/city/pt end up in identifiers unquoted;
        # only pass trusted values. The CHECK literals are quote-escaped below.
        pt_literal = str(pt).replace("'", "''")
        city_literal = str(city).replace("'", "''")

        drop_partition_stmt = text(f"DROP TABLE IF EXISTS {partition_tb_name};")
        create_partition_stmt = text(
            f"CREATE TABLE {partition_tb_name} "
            f"(CHECK (pt = '{pt_literal}' AND city = '{city_literal}')) "
            f"INHERITS ({table_name});"
        )

        try:
            self.conn.execute(drop_partition_stmt)
            self.conn.execute(create_partition_stmt)
            if records:
                # Reflect through the same connection so the not-yet-committed
                # table is visible.
                target_table = Table(partition_tb_name, MetaData(), autoload_with=self.conn)
                stmt = pg_insert(target_table).values(records)
                # To skip duplicates instead of failing:
                # stmt = stmt.on_conflict_do_nothing(index_elements=[target_table.c.id])
                self.conn.execute(stmt)
            self.conn.commit()
        except Exception:
            self.conn.rollback()
            raise

        print('insert records...')
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.
  • 31.


1.4 删

1、条件删除
    def delete_record(self, input_id):
        """Hard-delete the row whose primary key equals ``input_id``, then commit."""
        matches_id = TableClass.id == input_id
        self.session.execute(delete(TableClass).where(matches_id))
        self.session.commit()
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.


1.5 改

1、根据id
    def update_record(self, input_id, record_data):
        """
        Update the row whose primary key equals ``input_id`` and commit.

        :param input_id: primary-key value of the row to update
        :param record_data: mapping of column name -> new value; when empty or
            None the call is a no-op, because there would be nothing to SET
        :raises Exception: any execute/commit error is re-raised after the
            session has been rolled back
        """
        if not record_data:
            return

        stmt = update(TableClass).where(TableClass.id == input_id).values(**record_data)
        # Echo the generated SQL (bind parameters are shown as placeholders).
        print(stmt)
        try:
            self.session.execute(stmt)
            self.session.commit()
        except Exception:
            # Keep the session usable for later calls.
            self.session.rollback()
            raise
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
2、调用PostGIS函数更新geom字段
    def update_geom_col_value(self, partition_tb_name):
        """
        Ensure ``partition_tb_name`` has a ``geom`` column and populate it from ``track``.

        ``geom`` is added (if missing) as ``geometry(Geometry, 4326)`` and then set
        from the WKT in ``track`` for every row where ``track`` is not null. Each
        statement is committed on ``self.conn`` immediately after it runs.
        """
        statements = (
            f"alter table {partition_tb_name} add column if not exists geom geometry(Geometry, 4326);",
            f"UPDATE {partition_tb_name} SET geom = ST_SetSRID(ST_GeomFromText(track), 4326) where track is not null;",
        )
        for sql in statements:
            self.conn.execute(text(sql))
            self.conn.commit()
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.


1.6 查

1、指定输出字段条件查询
 
    def get_all_records(self):
        """
        Return all rows whose ``is_deleted`` is false, as a table of lists.

        The first element is a header row; each following element is
        ``[id, name, city, pt]`` for one record.
        """
        # SQLAlchemy needs `==` here (not `is`) to build the SQL comparison.
        query = select(TableClass).where(TableClass.is_deleted == False)  # noqa: E712
        found = self.session.execute(query)
        header = ['id', 'name', '分区城市编码', '分区日期']
        # NOTE(review): reads `.city`, which the TableClass shown in this file
        # does not declare — confirm the real model has it.
        body = [[rec.id, rec.name, rec.city, rec.pt] for rec in found.scalars()]
        return [header, *body]
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
2、不指定字段查询
    def get_partition_all_records(self, pt, city):
        """Return every column of every row matching ``pt`` and ``city`` as a list of Row objects."""
        # NOTE(review): filters on `city`, which the TableClass shown in this
        # file does not declare — confirm the real table has that column.
        every_column = TableClass.__table__.columns
        query = select(*every_column).filter_by(pt=pt, city=city)
        rows = self.session.execute(query).all()
        return rows
  • 1.
  • 2.
  • 3.
  • 4.


二、SQLModel

参考相关资料3(FastAPI+SQLModel开发增删改查接口)。

由 main 传递 db session 参数;

每执行一条 SQL,调用一次 db.commit()。