接上篇:一文掌握异步web框架FastAPI(四)-CSDN博客
目录
16、访问速率限制中间件,即限制每个IP特定时间内的请求数(基于内存,生产上要使用数据库)
2)增加限制单ip并发(跟上面的一样,也是限制每个IP特定时间内的请求数,另一种写法)
七、中间件
15、测试环境中间件
这个中间件用于识别请求是否来自于测试环境,并采取相应的措施,如禁用缓存或跳过某些安全检查。
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
import os
import logging
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
@app.middleware("http")
async def test_environment_middleware(request: Request, call_next):
try:
# 检查环境变量
test_environment_enabled = os.getenv("TEST_ENVIRONMENT_ENABLED", "false").lower() == "true"
if "X-Test-Environment" in request.headers and request.headers["X-Test-Environment"].lower() == "true":
if test_environment_enabled:
logger.info("Test environment header detected. Enabling test environment specific logic.")
# 测试环境特有的处理
# 例如:设置数据库连接为测试数据库、开启调试模式等
pass
else:
logger.warning("Test environment header detected, but TEST_ENVIRONMENT_ENABLED is not set to true.")
return JSONResponse({"message": "Test environment not enabled"}, status_code=403)
response = await call_next(request)
return response
except Exception as e:
logger.error(f"Error in test_environment_middleware: {str(e)}")
return JSONResponse({"message": "Internal server error"}, status_code=500)
@app.get("/")
async def root():
return {"message": "Hello World"}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
请求:
import requests
# FastAPI应用的URL
url = "http://127.0.0.1:8000/"
# 发送GET请求到根路径,并包含测试环境头部
response = requests.get(url, headers={"X-Test-Environment": "true"})
# 打印响应状态码和内容
print(f"Status Code: {response.status_code}")
print(f"Response Content: {response.json()}")
本地未启用测试环境:
16、访问速率限制中间件,即限制每个IP特定时间内的请求数(基于内存,生产上要使用数据库)
1)限制单ip访问速率
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import time
from collections import defaultdict
import logging
app = FastAPI()
# 初始化rate_limits字典
rate_limits = defaultdict(list)
# 配置日志
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# 定义中间件
@app.middleware("http")
async def rate_limit_middleware(request: Request, call_next):
ip = request.headers.get("X-Forwarded-For", request.client.host).split(",")[0].strip()
now = time.time()
# 移除旧的请求时间戳
rate_limits[ip] = [t for t in rate_limits[ip] if t > now - 60]
# 检查速率限制
if len(rate_limits[ip]) >= 10: # 限制为每分钟10个请求
logger.debug(f"Rate limit exceeded for IP: {ip}")
return JSONResponse(
status_code=429,
content={"detail": "Too Many Requests"}
)
# 添加当前请求的时间戳
rate_limits[ip].append(now)
# 继续处理请求
response = await call_next(request)
return response
# 定义路由
@app.get("/")
async def root():
return {"message": "Hello World"}
# 启动应用
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug")
请求:
import requests
import time
# FastAPI应用的URL
url = "http://127.0.0.1:8000/"
# 用于测试的IP地址
ip = "127.0.0.1"
# 设置请求头,模拟X-Forwarded-For
headers = {
"X-Forwarded-For": ip
}
# 发送多个请求以测试速率限制
def test_rate_limit():
for i in range(15):
try:
response = requests.get(url, headers=headers)
print(f"Request {i + 1}: Status Code: {response.status_code}")
if response.status_code == 429:
print(f"Request {i + 1}: {response.json()}")
except Exception as e:
print(f"Request {i + 1}: Error: {e}")
time.sleep(1) # 等待1秒