概要
本文将详细介绍如何使用 FastAPI 构建一个功能强大的网关服务,该网关服务能够处理认证、路由转发和日志记录等功能。我们将基于提供的代码文件进行分析,并对代码进行必要的优化和补充。
整体架构流程
- 数据库模型 (base.py)
from typing import List
from sqlalchemy import or_
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.declarative import as_declarative, declared_attr
from sqlalchemy.orm import Session
# 创建 Base 类,用于所有的 ORM 模型继承
@as_declarative()
class Base:
# 为所有模型添加 id 字段
id: int
@declared_attr
def __tablename__(cls) -> str:
# 自动生成表名,表名为类名的小写形式
return cls.__name__.lower()
# 这个方法可以方便定义通用的查询方法
@classmethod
def create(cls, db: Session, **kwargs):
"""
创建一个新实例并保存到数据库
:param db: 数据库会话
:param kwargs: 实例的字段值
:return: 创建的实例
"""
instance = cls(**kwargs)
db.add(instance)
db.commit()
db.refresh(instance)
return instance
@classmethod
def bulk_create(cls, db: Session, data: List[dict]) -> bool:
"""
批量创建多个实例,并保证事务的一致性
:param db: 数据库会话
:param data: 字典列表,每个字典代表一个实例的字段
:return: 创建成功返回 True,否则抛出异常
"""
try:
# 开始事务
with db.begin(): # 使用事务管理器,自动提交或回滚
instances = [cls(**item) for item in data]
db.add_all(instances) # 批量添加实例
db.commit() # 提交事务
return True
except SQLAlchemyError as e:
db.rollback() # 如果发生异常,回滚事务
raise Exception(f"批量插入失败: {str(e)}")
@classmethod
def update(cls, db: Session, instance_id: int, **kwargs):
"""
更新指定 ID 的实例
:param db: 数据库会话
:param instance_id: 实例的 ID
:param kwargs: 要更新的字段值
:return: 更新后的实例
"""
instance = db.query(cls).filter(cls.id == instance_id).first()
if instance:
for key, value in kwargs.items():
setattr(instance, key, value)
db.commit()
db.refresh(instance)
return instance
@classmethod
def delete(cls, db: Session, instance_id: int) -> bool:
"""
删除指定 ID 的实例
:param db: 数据库会话
:param instance_id: 实例的 ID
:return: 删除成功返回 True,否则返回 False
"""
instance = db.query(cls).filter(cls.id == instance_id).first()
if instance:
db.delete(instance)
db.commit()
return True
return False
@classmethod
def get(cls, db: Session, instance_id: int):
"""
根据 ID 获取实例
:param db: 数据库会话
:param instance_id: 实例的 ID
:return: 查询到的实例,未找到返回 None
"""
return db.query(cls).filter(cls.id == instance_id).first()
@classmethod
def all(cls, db: Session, search_params: dict = None, exact_match: bool = False):
"""
根据传入的查询条件进行查询。
:param db: 数据库会话
:param search_params: 查询参数字典,key 为字段名,value 为查询条件
:param exact_match: 是否为精确匹配,默认是模糊查询
:return: 查询结果列表
"""
query = db.query(cls)
if search_params:
filters = []
for field, value in search_params.items():
if hasattr(cls, field):
column = getattr(cls, field)
if exact_match:
filters.append(column == value) # 精确匹配
else:
filters.append(column.ilike(f"%{value}%")) # 模糊查询
if filters:
query = query.filter(or_(*filters))
return query.all()
- service.py
from fastapi import APIRouter, Request, HTTPException, status
from fastapi.responses import JSONResponse
from httpx import AsyncClient
from starlette.middleware.base import BaseHTTPMiddleware
from utils.logger_utils import logger
from .token_auth import get_current_user
router = APIRouter()
class AuthMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
"""
认证中间件,用于验证请求的合法性
:param request: 请求对象
:param call_next: 下一个中间件或路由处理函数
:return: 响应对象
"""
# 获取配置
app = request.app
white_list_paths = app.config.get("WHITE_LIST_PATHS", [])
# 输出日志
logger.debug(f"Request path: {request.url.path}")
logger.debug(f"White list paths: {white_list_paths}")
# 检查路径是否在白名单中
if white_list_paths and any(request.url.path.startswith(path) for path in white_list_paths):
response = await call_next(request)
return response
try:
# 验证token
token = request.headers.get("Authorization")
if not token:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing token")
token = token.split(" ")[1] # 去掉 "Bearer "
username = get_current_user(token)
logger.debug(f"Current username: {username}")
request.state.username = username
except HTTPException as e:
return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
response = await call_next(request)
return response
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"])
async def gateway(request: Request, path: str):
"""
网关路由处理函数,负责将请求转发到相应的后端服务
:param request: 请求对象
:param path: 请求路径
:return: 响应对象
"""
app = request.app
config = app.config
# 获取服务映射字典
service_urls = config.get("SERVICE_URLS", {})
# 确定服务名称和路径
service_name = None
for service, url in service_urls.items():
if path.startswith(f"{service}/"):
service_name = service
break
if service_name:
backend_service_url = service_urls[service_name]
# 去掉服务名称部分
path = path[len(service_name) + 1:]
else:
return JSONResponse(status_code=404, content={"detail": "未知的服务"})
async with AsyncClient() as client:
# 构建转发请求的URL
url = f"{backend_service_url}/{path}"
headers = {key: value for key, value in request.headers.items() if key != "host"}
response = await client.request(
method=request.method,
url=url,
headers=headers,
content=await request.body(),
params=request.query_params
)
logger.debug(f"Forwarding response to {response}")
return JSONResponse(content=response.json(), status_code=response.status_code)
- main.py
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from api.v1 import login, service
from nacos_config.nacos_utils import initialize_nacos, shutdown_nacos
app = FastAPI(title="ERP API", version="1.0.0")
# 添加CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 添加认证中间件
app.add_middleware(service.AuthMiddleware)
# 路由注册
app.include_router(login.router, prefix="/login", tags=["Login"])
app.include_router(service.router, tags=["service"])
# 初始化 Nacos 配置
service_registry = None
@app.on_event("startup")
async def startup_event():
"""
应用启动时的事件处理函数
"""
global service_registry
service_registry = await initialize_nacos(app)
@app.on_event("shutdown")
async def shutdown_event():
"""
应用关闭时的事件处理函数
"""
if service_registry:
await shutdown_nacos(service_registry)
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8001)
- token_auth.py
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError
from sqlalchemy.orm import Session
from core.database import get_db
from utils.jwt_utils import decode_access_token
from utils.logger_utils import logger
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
"""
从请求头中获取并验证 JWT Token
:param token: JWT Token
:param db: 数据库会话
:return: 用户名
"""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = decode_access_token(token)
if payload is None:
raise credentials_exception
username: str = payload.get("sub")
logger.debug(f"Decoded payload: {payload}")
if username is None:
raise credentials_exception
except JWTError as e:
logger.error(f"Failed to decode JWT token: {e}")
raise credentials_exception
logger.debug(f"Current user: {username}")
return username
- logger_utils.py
import logging
# 配置日志
logging.basicConfig(level=logging.DEBUG, # 设置日志级别为DEBUG,输出所有级别的日志
format="%(asctime)s - %(levelname)s - %(message)s") # 设置日志格式
logger = logging.getLogger(__name__) # 获取日志器
- jwt_utils.py
from datetime import datetime, timedelta
from typing import Optional
from jose import JWTError, jwt
from utils.logger_utils import logger
SECRET_KEY = "your_secret_key"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
"""
创建 JWT Access Token
:param data: 包含用户信息的字典
:param expires_delta: Token 的过期时间
:return: JWT Token 字符串
"""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
def verify_token(token: str):
"""
验证 JWT Token
:param token: JWT Token 字符串
:return: 解码后的 payload 或 None
"""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
return payload
except JWTError:
return None
def decode_access_token(token: str):
"""
解码 JWT Access Token
:param token: JWT Token 字符串
:return: 解码后的 payload 或 None
"""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
logger.debug(f"Decoded payload: {payload}")
return payload
except JWTError:
logger.error("Failed to decode JWT token")
return None
- database.py
from typing import Generator
from fastapi import Request
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
def get_db_config(request: Request):
"""
从请求对象中获取数据库配置
:param request: 请求对象
:return: 数据库 URL
"""
return request.app.db_config.get("DATABASE_URL")
def create_engine_with_config(request: Request):
"""
根据配置创建数据库引擎
:param request: 请求对象
:return: 数据库引擎
"""
DATABASE_URL = get_db_config(request)
engine = create_engine(DATABASE_URL, echo=True)
return engine
def get_sessionmaker(request: Request):
"""
根据配置创建 Session 类
:param request: 请求对象
:return: Session 类
"""
engine = create_engine_with_config(request)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
return SessionLocal
# 创建 Base 类,用于定义所有数据库模型
Base = declarative_base()
def get_db(request: Request) -> Generator:
"""
获取数据库会话的生成器函数
:param request: 请求对象
:return: 数据库会话
"""
SessionLocal = get_sessionmaker(request)
db = SessionLocal() # 创建一个会话实例
try:
yield db # 将会话对象提供给依赖注入系统
finally:
db.close() # 请求结束后,自动关闭数据库会话
- nacos_uitls.py(上一篇文章有说明python读取nacos配置)
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName :nacos_utils.py
# @Time :2025/2/14 下午3:57
# @Author :wxh
from nacos import NacosClient
from nacos_config.nacos_config import NACOS_SERVER_ADDRESSES, NACOS_NAMESPACE
from nacos_config.service_registry import ServiceRegistry
# 创建 Nacos 客户端实例
nacos_client = NacosClient(NACOS_SERVER_ADDRESSES, namespace=NACOS_NAMESPACE)
async def initialize_nacos(app):
"""
初始化 Nacos 配置和服务注册
:param app: FastAPI 应用实例
:return: 服务注册实例
"""
# 初始化服务注册
service_registry = ServiceRegistry()
# 注册服务到 Nacos
service_registry.register_service()
# 从 Nacos 获取配置
config = service_registry.get_config()
db_config = service_registry.get_db_config()
if config:
print(f"Config from Nacos: {config}")
print(f"dbConfig from Nacos: {db_config}")
# 将配置应用到应用中
app.config = config
app.db_config = db_config
else:
print("No config found in Nacos")
return service_registry
async def shutdown_nacos(service_registry):
"""
关闭 Nacos 配置和服务注册
:param service_registry: 服务注册实例
"""
# 注销服务从 Nacos
service_registry.deregister_service()
技术细节
- 文件目录
auth_service/
├── api/
│ └── v1/
│ ├── service.py
│ └── token_auth.py
├── core/
│ └── database.py
├── main.py
├── nacos_config/
│ ├── nacos_config.py
│ └── nacos_utils.py
└── utils/
├── jwt_utils.py
└── logger_utils.py
文件功能概述
- api/v1/service.py: 实现认证中间件和网关路由处理。
- api/v1/token_auth.py: 实现 JWT Token 的验证逻辑。
- core/database.py: 创建数据库引擎和会话管理。
- main.py: FastAPI 应用的主入口文件,负责初始化应用、注册中间件和路由。
- nacos_config/nacos_config.py: 包含 Nacos 的配置信息。
- nacos_config/nacos_utils.py: 实现 Nacos 的初始化和服务注册。
- utils/jwt_utils.py: 提供 JWT Token 的生成和验证功能。
- utils/logger_utils.py: 配置日志记录器。
- 数据库模型 (base.py)
from sqlalchemy.ext.declarative import as_declarative, declared_attr
@as_declarative()
class Base:
id: int
@declared_attr
def __tablename__(cls) -> str:
return cls.__name__.lower()
Base 类作为所有 ORM 模型的基类。
- create: 创建一个新实例并保存到数据库。
- bulk_create: 批量创建多个实例,并保证事务的一致性。
- update: 更新指定 ID 的实例。
- delete: 删除指定 ID 的实例。
- get: 根据 ID 获取实例。
- all: 根据传入的查询条件进行查询。
- 认证中间件 (service.py)
from fastapi import APIRouter, Request, HTTPException, status
from fastapi.responses import JSONResponse
from httpx import AsyncClient
from starlette.middleware.base import BaseHTTPMiddleware
from utils.logger_utils import logger
from .token_auth import get_current_user
router = APIRouter()
class AuthMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
app = request.app
white_list_paths = app.config.get("WHITE_LIST_PATHS", [])
logger.debug(f"Request path: {request.url.path}")
logger.debug(f"White list paths: {white_list_paths}")
if white_list_paths and any(request.url.path.startswith(path) for path in white_list_paths):
response = await call_next(request)
return response
try:
token = request.headers.get("Authorization")
if not token:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing token")
token = token.split(" ")[1]
username = get_current_user(token)
logger.debug(f"Current username: {username}")
request.state.username = username
except HTTPException as e:
return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
response = await call_next(request)
return response
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"])
async def gateway(request: Request, path: str):
app = request.app
config = app.config
service_urls = config.get("SERVICE_URLS", {})
service_name = None
for service, url in service_urls.items():
if path.startswith(f"{service}/"):
service_name = service
break
if service_name:
backend_service_url = service_urls[service_name]
path = path[len(service_name) + 1:]
else:
return JSONResponse(status_code=404, content={"detail": "未知的服务"})
async with AsyncClient() as client:
url = f"{backend_service_url}/{path}"
headers = {key: value for key, value in request.headers.items() if key != "host"}
response = await client.request(
method=request.method,
url=url,
headers=headers,
content=await request.body(),
params=request.query_params
)
logger.debug(f"Forwarding response to {response}")
return JSONResponse(content=response.json(), status_code=response.status_code)
- AuthMiddleware: 继承自 BaseHTTPMiddleware,用于验证请求的合法性。
- dispatch: 处理请求,检查路径是否在白名单中,验证 JWT Token。
- gateway: 处理所有请求,根据路径将请求转发到相应的后端服务。
小结
通过上述代码实现,我们构建了一个完整的 FastAPI 网关服务,具备以下功能:
用户认证和授权
请求转发到后端服务
日志记录
数据库操作
这些功能使得该网关服务可以作为企业级应用的统一入口,确保系统的安全性和可扩展性。