1. 什么是依赖注入?
依赖注入是一种软件设计模式,它允许将一个对象(依赖)传递给另一个对象,而不是让接收对象自己创建依赖。通过这种方式,可以降低代码耦合度,提高代码的可测试性和可维护性,同时增强代码的重用性。
在 FastAPI 中,依赖注入主要指的是自动将功能或数据传递给路由处理函数的机制。
2. FastAPI依赖注入基础
2.1 基本语法
FastAPI使用Depends
函数来声明依赖。以下是一个简单的例子:
from fastapi import FastAPI, Depends
app = FastAPI()
def get_query_param(q: str = None):
return q
@app.get("/items/")
async def read_items(query: str = Depends(get_query_param)):
return {"query": query}
在这个例子中,get_query_param
是一个依赖函数。当有请求到/items/
路径时,FastAPI会自动调用get_query_param
函数,并将其结果作为query
参数传递给read_items
函数。
2.2 为什么使用依赖注入?
- 代码复用: 可以在多个路由中重复使用相同的依赖。
- 关注点分离: 可以将特定功能(如认证、数据库操作)分离到依赖中。
- 易于测试: 可以轻松地模拟或替换依赖,使单元测试更容易。
3. 深入理解FastAPI依赖注入
3.1 依赖类型
FastAPI 支持多种类型的依赖,包括函数依赖、类依赖和生成器依赖。
3.1.1 函数依赖
最简单的依赖类型是函数依赖,如:
def get_query_param(q: str = None):
return q
@app.get("/items/")
async def read_items(query: str = Depends(get_query_param)):
return {"query": query}
3.1.2 类依赖
类依赖允许创建更复杂的依赖结构,例如:
class CommonQueryParams:
def __init__(self, q: str = None, skip: int = 0, limit: int = 100):
self.q = q
self.skip = skip
self.limit = limit
@app.get("/items/")
async def read_items(commons: CommonQueryParams = Depends(CommonQueryParams)):
return {"q": commons.q, "skip": commons.skip, "limit": commons.limit}
3.1.3 生成器依赖
生成器依赖特别适用于需要设置和清理操作的场景,如数据库会话管理:
from fastapi import Depends
from sqlalchemy.orm import Session
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
@app.get("/users/{user_id}")
def read_user(user_id: int, db: Session = Depends(get_db)):
user = db.query(User).filter(User.id == user_id).first()
return user
3.2 依赖树
依赖可以依赖于其他依赖,形成依赖树,例如:
def dep_a():
return {"a": 1}
def dep_b(a = Depends(dep_a)):
return {"b": 2, **a}
@app.get("/test")
async def test(b = Depends(dep_b)):
return b # 将返回 {"a": 1, "b": 2}
4. 实际应用场景
4.1 用户认证
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def get_current_user(token: str = Depends(oauth2_scheme)):
user = decode_token(token)
if not user:
raise HTTPException(status_code=401, detail="Invalid token")
return user
@app.get("/users/me")
async def read_users_me(current_user: User = Depends(get_current_user)):
return current_user
4.2 数据库会话管理
from fastapi import Depends
from sqlalchemy.orm import Session
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
@app.get("/items/")
def read_items(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
items = db.query(Item).offset(skip).limit(limit).all()
return items
4.3 请求日志记录
import time
from fastapi import Request, FastAPI, Depends
app = FastAPI()
async def log_request(request: Request):
start_time = time.time()
yield
process_time = time.time() - start_time
print(f"Request to {request.url.path} took {process_time:.6f} seconds")
@app.get("/items/", dependencies=[Depends(log_request)])
async def read_items():
return {"items": []}
if __name__ == '__main__':
import uvicorn
uvicorn.run(app)
5. 高级技巧
5.1 参数化依赖
有时我们需要依赖函数接受一些参数来动态生成依赖项。通过使用工厂函数,我们可以创建带参数的依赖项。
from fastapi import FastAPI, Depends
app = FastAPI()
def query_extractor(field: str):
def extract_query(query_param: str = None):
if not query_param:
return None
return {field: query_param}
return extract_query
@app.get("/items/")
async def read_items(query=Depends(query_extractor("item"))):
return {"query": query}
if __name__ == '__main__':
import uvicorn
uvicorn.run(app)
在这个例子中,query_extractor 是一个工厂函数,它接受一个参数 field 并返回一个内部函数 extract_query。当 read_items 路由被访问时,query_extractor(“item”) 被调用,生成一个依赖函数 extract_query,这个函数处理传入的查询参数并返回一个包含该参数的字典。
5.2 全局依赖
有时,我们希望某些依赖项对应用中的所有路由都生效。我们可以通过在创建 FastAPI 实例时传递 dependencies 参数来实现这一点。
from fastapi.security import OAuth2PasswordBearer
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
api_key_header = "x-api-key"
def verify_token(token: str = Depends(oauth2_scheme)):
# 验证 token 的逻辑
if token != "valid-token":
raise HTTPException(status_code=401, detail="Invalid token")
return token
def verify_key(key: str = Depends(api_key_header)):
# 验证 API key 的逻辑
if key != "valid-key":
raise HTTPException(status_code=401, detail="Invalid key")
return key
app = FastAPI(dependencies=[Depends(verify_token), Depends(verify_key)])
在这个例子中,verify_token 和 verify_key 函数将作为全局依赖应用于所有路由,确保每个请求都需要通过 token 和 API key 的验证。
5.3 依赖覆盖
在测试中,我们经常需要覆盖某些依赖项,以便模拟特定的条件或行为。FastAPI 提供了 dependency_overrides 属性来实现这一点。
from fastapi import FastAPI, Depends
from typing import Dict
app = FastAPI()
# 假设 original_dep 是我们需要覆盖的依赖
def original_dep():
return {"original": "data"}
@app.get("/test")
async def a(dep: Dict[str, str] = Depends(original_dep)):
return dep
# 覆盖依赖
def override_dep():
return {"test": "data"}
app.dependency_overrides[original_dep] = override_dep
if __name__ == '__main__':
import uvicorn
uvicorn.run(app)
在这个例子中,我们定义了一个 override_dep 函数来覆盖 original_dep 依赖。通过将 override_dep 分配给 app.dependency_overrides[original_dep],我们可以在测试中使用 override_dep 代替 original_dep。
6. 性能考虑
FastAPI默认会缓存依赖的结果。这意味着对于同一个请求,如果多个路由或依赖使用了相同的依赖,该依赖只会被调用一次。这种行为通常能提高性能,但在某些情况下可能不是我们想要的。
可以通过设置use_cache=False
来禁用这个行为:
from fastapi import FastAPI, Depends
app = FastAPI()
def heavy_computation():
# 一些耗时的操作
return "ok"
@app.get("/items/")
async def read_items(result: dict = Depends(heavy_computation, use_cache=False)):
return result
在这个例子中,每次调用read_items
时,heavy_computation
都会被重新执行。
7. 实际项目中的应用
以下是一个复杂的例子,展示了如何在实际项目中组合使用多个依赖来处理复杂的业务逻辑:
from fastapi import FastAPI, Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import Session
from sqlalchemy import create_engine, Column, Integer, String, Boolean
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from typing import List
DATABASE_URL = "sqlite:///./test.db"
# 数据库设置
engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
# 数据库模型
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
username = Column(String, unique=True, index=True)
is_admin = Column(Boolean, default=False)
class UserOut(Base):
__tablename__ = "users_out"
id = Column(Integer, primary_key=True, index=True)
username = Column(String, unique=True, index=True)
class Item(Base):
__tablename__ = "items"
id = Column(Integer, primary_key=True, index=True)
title = Column(String, index=True)
owner_id = Column(Integer)
class ItemCreate(Base):
__tablename__ = "items_create"
title = Column(String, index=True)
# 创建数据库表
Base.metadata.create_all(bind=engine)
# FastAPI 实例
app = FastAPI()
# OAuth2 认证设置
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# 模拟认证函数
def authenticate_user(token: str, db: Session):
# 这里应该实现实际的用户认证逻辑
user = db.query(User).filter(User.id == token).first()
return user
# 数据库依赖
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
# 用户认证依赖
def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
user = authenticate_user(token, db)
if not user:
raise HTTPException(status_code=401, detail="无效的身份验证凭据")
return user
# 管理员权限检查依赖
def check_admin(current_user: User = Depends(get_current_user)):
if not current_user.is_admin:
raise HTTPException(status_code=403, detail="权限不足")
return current_user
# 分页参数依赖
class PaginationParams:
def __init__(self, skip: int = 0, limit: int = 100):
self.skip = skip
self.limit = limit
# 路由
@app.get("/users/", response_model=List[UserOut])
def read_users(
pagination: PaginationParams = Depends(),
db: Session = Depends(get_db),
current_user: User = Depends(check_admin)
):
users = db.query(User).offset(pagination.skip).limit(pagination.limit).all()
return users
@app.post("/items/")
def create_item(
item: ItemCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
db_item = Item(**item.dict(), owner_id=current_user.id)
db.add(db_item)
db.commit()
db.refresh(db_item)
return db_item
代码解析
- 数据库设置:配置数据库连接和模型,并创建数据库表。
- 认证设置:配置 OAuth2 认证,定义认证函数 authenticate_user。
- 依赖函数:
- get_db:管理数据库会话。
- get_current_user:通过 token 获取当前用户。
- check_admin:检查用户是否为管理员。
- 分页参数依赖:定义一个 PaginationParams 类,用于处理分页参数。
- 路由:
- read_users:返回用户列表,仅管理员可访问。
- create_item:创建新项目,仅已认证用户可访问。
通过这种方式,我们可以利用 FastAPI 的依赖注入系统,在实际项目中实现复杂的业务逻辑。