FastAPI 项目架构设计
结构树
fastapi_project
├─auth
│ │ base.py
│ └──__init__.py
│
├─event
│ │ startup_event.py
│ │ shutdown_event.py
│ └──__init__.py
│
├─exception
│ │ base.py
│ └──__init__.py
│
├─logs
│
├─medias
│
├─model
│ │ curd.py
│ │ database.py
│ │ models.py
│ │ schemas.py
│ └──__init__.py
│
├─response
│ │ base.py
│ └──__init__.py
│
├─router
│ │ base.py
│ └──__init__.py
│
├─util
│ │ SM4EncryptDecrypt.py
│ └──__init__.py
│
├─websocket
│ │ base.py
│ └──__init__.py
│
│ logging.ini
│ main.py
└──setting.py
结构详情
auth
**描述:**存放项目所有认证相关脚本
base.py
# base.py
from fastapi import Header
from starlette.authentication import AuthenticationError
def verify_base(x_token: str = Header(...)):
    """
    Basic request-verification dependency.

    ``Header(...)`` makes the ``X-Token`` request header mandatory, so FastAPI
    rejects requests without it before this function is called. The check
    itself is not implemented yet: any token value is accepted.

    :param x_token: value of the X-Token request header
    :return: None
    """
    # TODO: validate x_token and reject the request when it is invalid.
    return None
event
**描述:**存放项目启动时和关闭时执行的脚本
startup_event.py
# startup_event.py
def start_services():
    """
    Hook intended to run when the application starts up.

    Put initialization that must finish before requests are served here.
    It currently does nothing.
    """
    return None
startdown_event.py
# startdown_event.py
def end_services():
    """
    Hook intended to run when the application shuts down.

    Put cleanup work (closing connections, flushing buffers, ...) here.
    It currently does nothing.
    """
    return None
exception
**描述:**存放项目异常捕获脚本
base.py
# base.py
from fastapi.requests import Request
def base_exception_handler(request: Request, exc):
    """
    Turn an exception raised while handling a request into the unified error response.

    The body always has the shape ``{"code": 500, "data": ..., "msg": ...}``:

    * if the exception's first argument is a string, it becomes ``msg``;
    * if the exception exposes ``errors()`` (validation errors), each error is
      reported in ``data`` as ``{"<loc joined by ----->": "<msg>"}``;
    * otherwise a generic message is returned.

    :param request: the request being handled (unused; required by the handler signature)
    :param exc: the caught exception
    :return: a JSON response with HTTP status 500
    """
    # Starlette awaits the handler's return value as an ASGI response, so a plain
    # dict cannot be returned; wrap the body in a JSONResponse instead.
    from fastapi.responses import JSONResponse

    # Exceptions raised without arguments have an empty args tuple.
    first_arg = exc.args[0] if exc.args else None
    if isinstance(first_arg, str):
        content = {"code": 500, "data": None, "msg": first_arg}
    elif callable(getattr(exc, "errors", None)):
        # 'loc' may contain ints (list indexes), so stringify before joining.
        data = [{'----'.join(map(str, err['loc'])): err['msg']} for err in exc.errors()]
        content = {"code": 500, "data": data, "msg": "服务端错误"}
    else:
        content = {"code": 500, "data": None, "msg": "服务端错误"}
    return JSONResponse(status_code=500, content=content)
logs
**描述:**存放项目日志
medias
**描述:**存放项目静态文件
model
**描述:**存放项目数据库相关脚本
curd.py
# curd.py
from model.database import SessionLocal
def get_db():
    """
    Generator dependency that provides one database session.

    Yields a new ``SessionLocal()`` session and closes it when the generator
    is finished or closed. The ``finally`` runs even if the consumer raised.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Always release the session's connection resources.
        session.close()
# 所有与数据库相关的增删改查的操作方法
database.py
# database.py
from datetime import datetime
from sqlalchemy import create_engine, Column, Integer, DateTime
from sqlalchemy.orm import sessionmaker, as_declarative, declared_attr
from setting import DBConfig
import urllib.parse

HOST = DBConfig["host"]
PORT = DBConfig["port"]
USER = DBConfig["user"]
PWD = DBConfig["pwd"]
DB_NAME = DBConfig["database"]
# Percent-encode the credentials. A raw '@', ':', '/' or '%' in the password would
# otherwise be parsed as URL structure and corrupt the connection string.
# SQLAlchemy decodes them again when it parses the URL. safe='' also encodes '/'.
_QUOTED_USER = urllib.parse.quote(str(USER), safe='')
_QUOTED_PWD = urllib.parse.quote(str(PWD), safe='')
SQLALCHEMY_DATABASE_URL = f'mysql+pymysql://{_QUOTED_USER}:{_QUOTED_PWD}@{HOST}:{PORT}/{DB_NAME}'
# pool_pre_ping tests pooled connections before use and replaces dead ones.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL, pool_pre_ping=True
)
# Session factory. Autoflush is off, so queries do not implicitly flush pending changes.
SessionLocal = sessionmaker(autoflush=False, bind=engine)

# Reference: Column() keyword arguments used by the models
#   primary_key    whether the column is the primary key
#   unique         whether values must be unique
#   index          if True, create an index on the column to speed up queries
#   nullable       whether NULL is allowed
#   default        default value (or a callable producing it)
#   name           column name in the database table
#   autoincrement  whether the value auto-increments
#   onupdate       function executed when the row is updated
#   comment        column description
@as_declarative()
class Base:
    """
    Declarative base for all ORM models.

    The columns declared here (id, create_time, update_time, is_delete) are
    shared by every model. ``__tablename__`` is derived from the class name.
    """
    # A primary key is already unique and indexed. The former unique=True/index=True
    # made SQLAlchemy create an additional, redundant unique index on every table.
    id = Column(Integer, primary_key=True, autoincrement=True, comment='ID')
    # datetime.now is passed uncalled, so it is evaluated per insert/update.
    create_time = Column(DateTime, default=datetime.now, comment="创建时间")
    update_time = Column(DateTime, default=datetime.now, onupdate=datetime.now, comment="更新时间")
    # Soft-delete flag: 0 = exists, 1 = deleted.
    is_delete = Column(Integer, default=0, comment="逻辑删除:0=存在,1=删除")
    __name__: str

    # Generate __tablename__ automatically: the lower-cased class name.
    @declared_attr
    def __tablename__(cls) -> str:
        return cls.__name__.lower()
models.py
# models.py
from model.database import Base, engine
# 数据库模型
if __name__ == "__main__":
    # Running this file directly creates the tables registered on Base's metadata.
    Base.metadata.create_all(bind=engine)
schemas.py
# schemas.py
from pydantic import BaseModel
# Common base for the project's Pydantic schemas (request bodies, responses, ...).
# Documented with comments rather than a docstring: Pydantic reads class docstrings
# (including inherited ones) into the JSON-schema "description" of every subclass.
class Model(BaseModel):
    class Config:
        # Allow building schemas from ORM objects (attribute access), not only dicts.
        # NOTE(review): Pydantic v2 renamed this option to `from_attributes`.
        orm_mode = True
# 项目所用到的所有数据结构,例如:请求体结构,响应体结构等
response
**描述:**存放项目统一响应结构(格式化返回值)脚本
base.py
# base.py
from typing import Any, Optional, Dict
from starlette.background import BackgroundTask
from starlette.responses import JSONResponse
class Response(JSONResponse):
    """
    JSON response carrying the project's unified ``{"code", "msg", "data"}`` body.

    Two ways of constructing it are supported:

    * ``Response(code=200, msg='...', data=...)`` builds the unified body. Unless
      ``status_code`` is passed explicitly, the HTTP status equals ``code``.
    * When this class is FastAPI's ``default_response_class``, FastAPI calls
      ``Response(content, status_code=..., background=...)`` and passes the
      endpoint's return value as the first positional argument. A first
      argument that is not an ``int`` is therefore treated as ready-made
      content: it is serialized unchanged, with HTTP status ``status_code``
      (default 200).

    Limitation: an endpoint that returns a bare int cannot be told apart from
    ``Response(code)`` and is wrapped into the unified body.
    """

    def __init__(self,
                 code: Any = 200,
                 msg: str = '操作成功',
                 data: Any = None,
                 status_code: Optional[int] = None,
                 headers: Optional[Dict[str, str]] = None,
                 background: Optional[BackgroundTask] = None
                 ) -> None:
        # bool is a subclass of int, but True/False are endpoint data, not codes.
        if isinstance(code, int) and not isinstance(code, bool):
            content = {
                'code': code,
                'msg': msg,
                'data': data
            }
            # Previously status_code was ignored and the HTTP status always mirrored code;
            # keep that as the default, but honour an explicit status_code.
            http_status = code if status_code is None else status_code
        else:
            # Called by FastAPI with the endpoint's (already JSON-encodable) return value.
            content = code
            http_status = 200 if status_code is None else status_code
        super().__init__(content=content, status_code=http_status, headers=headers, media_type='application/json',
                         background=background)
router
**描述:**存放项目所有接口
base.py
# base.py
import logging
from fastapi import APIRouter
router = APIRouter()
logger = logging.getLogger(__name__)
@router.get("/", tags=["后台:系统介绍"])
def index():
    """
    测试路由
    :return:
    """
    # Smoke-test endpoint that returns a fixed welcome message in the unified
    # {"code", "data", "msg"} body. The docstring above is left untouched
    # because FastAPI publishes it as this route's OpenAPI description.
    welcome = {"code": 200, "data": None, "msg": "欢迎使用系统"}
    return welcome
util
**描述:**存放项目所有工具类脚本
SM4EncryptDecrypt.py
# SM4EncryptDecrypt.py
from gmssl.sm4 import CryptSM4, SM4_ENCRYPT, SM4_DECRYPT
import binascii
from setting import EncryptKey
class SM4EncryptDecrypt:
    """
    SM4 (Chinese national standard block cipher) encryption/decryption helper.

    The global ``EncryptKey`` from settings is used for both directions, with
    gmssl's ECB mode.

    NOTE(review): ECB is deterministic. Equal plaintexts give equal ciphertexts,
    and repeated blocks are visible. It is kept so existing ciphertexts stay
    decryptable; prefer CBC with a random IV for new data.
    """

    def __init__(self):
        # The same key is used to encrypt and decrypt.
        self.encrypt_key = self.decrypt_key = EncryptKey
        self.crypt_sm4 = CryptSM4()

    @staticmethod
    def str_to_hex(hex_str):
        """
        Decode a hex string into the UTF-8 text it encodes.

        Despite its name, this converts hex *to* str, e.g. ``"616263" -> "abc"``.
        The name is kept for backward compatibility.

        :param hex_str: hex digits (even length)
        :return: the decoded text
        :raises binascii.Error: if ``hex_str`` is not valid hex
        :raises UnicodeDecodeError: if the bytes are not valid UTF-8
        """
        return binascii.unhexlify(hex_str.encode('utf-8')).decode('utf-8')

    def encrypt(self, value, salt=''):
        """
        Encrypt ``value + salt`` with SM4 (ECB).

        :param value: plaintext string
        :param salt: string appended to the plaintext before encryption (default: none)
        :return: the ciphertext as a lowercase hex string
        """
        crypt_sm4 = self.crypt_sm4
        crypt_sm4.set_key(self.encrypt_key.encode(), SM4_ENCRYPT)
        encrypt_value = crypt_sm4.crypt_ecb((value + salt).encode())  # bytes
        return encrypt_value.hex()

    def decrypt(self, encrypt_value, salt=''):
        """
        Decrypt a hex ciphertext produced by ``encrypt``.

        :param encrypt_value: the ciphertext as a hex string
        :param salt: the salt that was given to ``encrypt``. When non-empty and present
            at the end of the plaintext, it is stripped. By default nothing is
            stripped, so the result is ``value + salt`` as before.
        :return: the decrypted text
        """
        crypt_sm4 = self.crypt_sm4
        crypt_sm4.set_key(self.decrypt_key.encode(), SM4_DECRYPT)
        decrypt_value = crypt_sm4.crypt_ecb(bytes.fromhex(encrypt_value))  # bytes
        # Same result as the former str_to_hex(decrypt_value.hex()) round trip.
        plain = decrypt_value.decode('utf-8')
        if salt and plain.endswith(salt):
            plain = plain[:-len(salt)]
        return plain
websocket
**描述:**存放项目所有websocket接口
base.py
# base.py
import logging
from fastapi import APIRouter, WebSocket
router = APIRouter()
logger = logging.getLogger(__name__)
@router.websocket("")
async def index(websocket: WebSocket):
    """
    Test WebSocket endpoint.

    After the handshake, the welcome message is sent once. The same message
    is then sent as a reply to every text frame the client sends, until the
    client disconnects.
    """
    from fastapi import WebSocketDisconnect

    await websocket.accept()
    try:
        await websocket.send_text("欢迎使用系统")
        # Reply per client message instead of an unbounded send loop, which
        # flooded the client and could starve the event loop.
        while True:
            await websocket.receive_text()
            await websocket.send_text("欢迎使用系统")
    except WebSocketDisconnect:
        # The client is gone and the connection is already closed.
        logger.info("websocket client disconnected")
    except Exception as e:
        logger.exception("websocket handler failed")
        # A close reason is limited to 123 bytes (RFC 6455); trim on a UTF-8 boundary.
        reason = str(e).encode('utf-8')[:123].decode('utf-8', 'ignore')
        await websocket.close(code=1011, reason=reason)
logging.ini
# logging.ini
[loggers]
keys = root, gunicorn.error, gunicorn.access, uvicorn.error, uvicorn.access
[handlers]
keys = error_file, access_file, consoleHandler
[formatters]
keys = generic, access, normalFormatter
# Root logger. INFO is the level that was effectively in force before. A trailing
# comma in [loggers] keys used to add an empty logger name, and its [logger_]
# section (empty qualname = root) reconfigured the root logger to INFO.
[logger_root]
level = INFO
handlers = access_file, consoleHandler
[logger_uvicorn.error]
level = INFO
handlers = error_file, consoleHandler
qualname = uvicorn.error
propagate = 0
[logger_uvicorn.access]
level = INFO
handlers = access_file, consoleHandler
qualname = uvicorn.access
propagate = 0
# propagate = 0 like the other named loggers. With propagation on, every record was
# also handled by the root logger's handlers: printed twice on the console and
# copied into access.log.
[logger_gunicorn.error]
level = INFO
handlers = error_file, consoleHandler
propagate = 0
qualname = gunicorn.error
[logger_gunicorn.access]
level = INFO
handlers = access_file, consoleHandler
propagate = 0
qualname = gunicorn.access
# Console output on stdout.
[handler_consoleHandler]
class = StreamHandler
level = INFO
formatter = normalFormatter
args = (sys.stdout,)
# File handlers. Paths are relative to the process's working directory, and the
# logs/ directory must already exist there. UTF-8 is explicit so non-ASCII
# (e.g. Chinese) messages do not depend on the platform's locale encoding.
[handler_error_file]
class = logging.FileHandler
formatter = generic
args = ('logs/error.log', 'a', 'utf-8')
[handler_access_file]
class = logging.FileHandler
formatter = access
args = ('logs/access.log', 'a', 'utf-8')
# Used by error_file: like "access" plus call_trace (full source path and line number).
[formatter_generic]
format = [%(asctime)s] %(levelname)s %(name)s %(funcName)s() L%(lineno)-4d call_trace=%(pathname)s L%(lineno)-4d in %(module)s: %(message)s
datefmt = %Y-%m-%d %H:%M:%S
class = logging.Formatter
# Used by access_file.
[formatter_access]
format = [%(asctime)s] %(levelname)s %(name)s %(funcName)s() L%(lineno)-4d in %(module)s: %(message)s
datefmt = %Y-%m-%d %H:%M:%S
class = logging.Formatter
# Used by consoleHandler.
[formatter_normalFormatter]
format = %(asctime)s loglevel=%(levelname)-6s logger=%(name)s %(funcName)s() L%(lineno)-4d %(message)s
datefmt = %Y-%m-%d %H:%M:%S
class = logging.Formatter
main.py
# main.py
import logging.config
import os
import uvicorn
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.gzip import GZipMiddleware
from exception.base import base_exception_handler
from router import base as api_base
from websocket import base as socket_base
from response.base import Response
from event.startup_event import start_services
from event.startdown_event import end_services
# Absolute path, so the logging config is found regardless of the working directory.
LOGGING_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'logging.ini')
logging.config.fileConfig(fname=LOGGING_CONFIG, disable_existing_loggers=False)

# Routes use the project's Response class unless they specify another one.
app = FastAPI(default_response_class=Response)

# Starlette only accepts the event types 'startup' and 'shutdown'. Any other name
# (such as the previous 'startdown') fails an assertion when this module is imported.
app.add_event_handler('startup', start_services)
app.add_event_handler('shutdown', end_services)

app.include_router(api_base.router, prefix="/api/v1")
app.include_router(socket_base.router, prefix="/socket/v1")

# NOTE(review): "*" together with allow_credentials=True makes Starlette echo the
# caller's Origin for credentialed requests. Any site may then call the API with
# the user's cookies, so restrict this list in production.
origins = [
    "*",
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["GET", "POST", "PATCH", "DELETE"],
    allow_headers=["*"]
)
# Compress response bodies of at least 1000 bytes.
app.add_middleware(GZipMiddleware, minimum_size=1000)
# Uncaught exceptions are passed to base_exception_handler.
app.add_exception_handler(Exception, base_exception_handler)

if __name__ == "__main__":
    # uvicorn exits with an error when workers > 1 and the app is passed as an object.
    # The app must be an import string, so each worker process can import it itself.
    uvicorn.run(app="main:app", host='0.0.0.0', port=5000, log_config=LOGGING_CONFIG, workers=30)
setting.py
# setting.py
# Database connection settings (read by model/database.py).
DBConfig = {
    "host": '127.0.0.1',
    "port": 3306,
    "user": 'root',
    # model/database.py reads DBConfig["pwd"]; fill in the real password here.
    "pwd": '',
    "database": 'fastapi'
}
# Global encryption key, used by util/SM4EncryptDecrypt.py.
# NOTE(review): SM4 keys are 128-bit (16 bytes); the empty default must be replaced before use.
EncryptKey = ''
# Token lifetime, presumably in seconds (60 * 60 * 24 = one day).
ExpireDelta = 60 * 60 * 24
# Media (static) file locations.
MEDIAS_ROOT = r''
MediasPath = 'medias'