【核心精通】Django中间件开发与应用:10个高级技巧打造强大Web应用
前言:中间件如何成为Django应用的隐形强者?
在Django的架构中,中间件(Middleware)扮演着至关重要的角色,它们像无形的守卫一样拦截请求和响应,可以在不修改现有视图代码的情况下为整个应用添加功能。根据统计,超过90%的Django项目都使用了至少3个自定义中间件来解决跨领域关注点,如安全、性能监控和用户跟踪等。然而,大多数开发者对中间件的了解仅限于使用内置组件,而未充分发挥其潜力。本文将深入探讨Django中间件的开发与应用,揭示如何通过10个高级技巧创建强大且高效的中间件组件,为你的Django项目增添新的维度。
1. 中间件基础概念
1.1 什么是Django中间件?
中间件是Django请求/响应处理流程中的钩子机制,能够处理全局性的问题,如安全检查、会话管理、用户认证等。它们被配置在应用的请求处理管道中,按顺序处理每个请求和响应。
# settings.py中的中间件配置
MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
# 自定义中间件
'myapp.middleware.custom_middleware.CustomMiddleware',
]
1.2 请求/响应流程与中间件执行顺序
中间件在请求/响应流程中的执行顺序非常重要:
- 请求从顶到底通过中间件栈
- 到达视图函数
- 响应从底到顶返回通过中间件栈
请求处理流程:
Browser → [Security] → [Session] → [CSRF] → [Auth] → ... → View
响应处理流程:
View → ... → [Auth] → [CSRF] → [Session] → [Security] → Browser
中间件顺序会影响功能,例如SessionMiddleware
必须在AuthenticationMiddleware
之前,因为后者依赖于前者创建的会话。
1.3 中间件类型与钩子方法
Django中间件类可以定义以下方法:
class SimpleMiddleware:
def __init__(self, get_response):
self.get_response = get_response
# 一次性配置和初始化
def __call__(self, request):
# 处理视图调用前的请求
response = self.get_response(request)
# 处理视图调用后的响应
return response
def process_view(self, request, view_func, view_args, view_kwargs):
# 在视图函数调用前执行
return None
def process_exception(self, request, exception):
# 视图引发异常时调用
return None
def process_template_response(self, request, response):
# 对模板响应进行处理
return response
2. 创建第一个中间件
2.1 中间件基本结构
让我们创建一个简单的中间件,用于记录请求处理时间:
# myapp/middleware/timing.py
import time
import logging
logger = logging.getLogger(__name__)
class RequestTimingMiddleware:
def __init__(self, get_response):
self.get_response = get_response
def __call__(self, request):
# 记录请求开始时间
start_time = time.time()
# 调用视图
response = self.get_response(request)
# 计算请求处理时间
duration = time.time() - start_time
logger.info(f"Request to {request.path} took {duration:.4f}s")
# 添加处理时间到响应头
response['X-Request-Time'] = str(duration)
return response
在settings.py
中启用中间件:
MIDDLEWARE = [
# ...其他中间件
'myapp.middleware.timing.RequestTimingMiddleware',
]
2.2 同步和异步中间件
从Django 3.1开始,中间件支持异步处理:
# 同步中间件
class SyncMiddleware:
def __init__(self, get_response):
self.get_response = get_response
def __call__(self, request):
# 同步处理
return self.get_response(request)
# 异步中间件
class AsyncMiddleware:
def __init__(self, get_response):
self.get_response = get_response
async def __call__(self, request):
# 异步处理请求
response = await self.get_response(request)
return response
混合环境中的兼容性处理:
import asyncio
class CompatibleMiddleware:
def __init__(self, get_response):
self.get_response = get_response
# 确定get_response是同步或异步
self.is_async = asyncio.iscoroutinefunction(get_response)
async def __call__(self, request):
if self.is_async:
# 异步处理
return await self.get_response(request)
else:
# 同步处理在异步环境下
return await asyncio.to_thread(self.get_response, request)
2.3 中间件的启用与禁用
根据条件动态启用/禁用中间件:
class ConditionalMiddleware:
def __init__(self, get_response):
self.get_response = get_response
# 从配置决定是否启用
from django.conf import settings
self.enabled = getattr(settings, 'ENABLE_CUSTOM_MIDDLEWARE', True)
def __call__(self, request):
if not self.enabled:
# 中间件被禁用,直接传递请求
return self.get_response(request)
# 中间件启用,执行额外逻辑
# ...
response = self.get_response(request)
# 处理响应
# ...
return response
3. 常见中间件应用场景
3.1 用户认证与授权中间件
实现基于JWT的认证中间件:
# myapp/middleware/jwt_auth.py
import jwt
from django.contrib.auth.models import AnonymousUser
from django.conf import settings
from django.utils.functional import SimpleLazyObject
from myapp.models import User
def get_user_from_token(request):
# 从请求头获取令牌
auth_header = request.headers.get('Authorization', '')
if not auth_header.startswith('Bearer '):
return AnonymousUser()
token = auth_header.split(' ')[1]
try:
# 验证令牌
payload = jwt.decode(
token,
settings.JWT_SECRET_KEY,
algorithms=['HS256']
)
# 获取用户
user_id = payload.get('user_id')
if user_id:
return User.objects.get(id=user_id)
except (jwt.InvalidTokenError, User.DoesNotExist):
pass
return AnonymousUser()
class JWTAuthMiddleware:
def __init__(self, get_response):
self.get_response = get_response
def __call__(self, request):
# 使用SimpleLazyObject延迟加载用户
# 只有在需要时才验证令牌和查询数据库
request.user = SimpleLazyObject(lambda: get_user_from_token(request))
return self.get_response(request)
3.2 请求/响应修改中间件
添加安全头的中间件:
class SecurityHeadersMiddleware:
def __init__(self, get_response):
self.get_response = get_response
def __call__(self, request):
response = self.get_response(request)
# 添加安全相关的HTTP头
response['Content-Security-Policy'] = "default-src 'self'"
response['X-Content-Type-Options'] = 'nosniff'
response['X-Frame-Options'] = 'DENY'
response['X-XSS-Protection'] = '1; mode=block'
response['Referrer-Policy'] = 'strict-origin-when-cross-origin'
# 仅在HTTPS时设置
if request.is_secure():
age = 60 * 60 * 24 * 365 # 1年
response['Strict-Transport-Security'] = f'max-age={age}; includeSubDomains'
return response
3.3 异常处理中间件
创建全局异常处理中间件:
import json
import traceback
import logging
from django.http import JsonResponse, HttpResponse
from django.conf import settings
logger = logging.getLogger(__name__)
class GlobalExceptionMiddleware:
def __init__(self, get_response):
self.get_response = get_response
def __call__(self, request):
try:
return self.get_response(request)
except Exception as exc:
return self.handle_exception(request, exc)
def handle_exception(self, request, exception):
# 记录异常
logger.error(
f"Unhandled exception: {str(exception)}\n"
f"URL: {request.path}\n"
f"{traceback.format_exc()}"
)
# 确定是否为API请求
is_api_request = request.path.startswith('/api/') or \
request.headers.get('Accept') == 'application/json'
# API请求返回JSON响应
if is_api_request:
return self.handle_api_exception(request, exception)
# 常规请求渲染错误页面
return self.handle_html_exception(request, exception)
def handle_api_exception(self, request, exception):
status_code = getattr(exception, 'status_code', 500)
data = {
'error': True,
'message': str(exception),
'detail': str(exception) if settings.DEBUG else None,
'type': exception.__class__.__name__
}
return JsonResponse(data, status=status_code)
def handle_html_exception(self, request, exception):
if settings.DEBUG:
# 在DEBUG模式下重新抛出异常,显示详细信息
raise
# 生产环境返回友好的错误页面
from django.template.response import TemplateResponse
return TemplateResponse(
request,
'errors/500.html',
context={'error_message': str(exception)},
status=500
)
3.4 性能监控中间件
创建性能监控中间件:
import time
import uuid
from django.db import connection
from django.conf import settings
class PerformanceMonitoringMiddleware:
def __init__(self, get_response):
self.get_response = get_response
def __call__(self, request):
# 生成请求ID,用于跟踪整个请求过程
request_id = str(uuid.uuid4())
request.request_id = request_id
# 记录开始时间和数据库查询次数
start_time = time.time()
start_queries = len(connection.queries)
# 处理请求
response = self.get_response(request)
# 计算处理时间和数据库查询统计
duration = time.time() - start_time
db_queries = len(connection.queries) - start_queries
# 将统计信息添加到响应头
response['X-Request-ID'] = request_id
response['X-Request-Time'] = f"{duration:.4f}s"
response['X-DB-Queries'] = str(db_queries)
# 记录慢请求
threshold = getattr(settings, 'SLOW_REQUEST_THRESHOLD', 1.0)
if duration > threshold:
self.log_slow_request(request, duration, db_queries)
return response
def log_slow_request(self, request, duration, db_queries):
import logging
logger = logging.getLogger('performance')
logger.warning(
f"Slow request detected: {request.method} {request.path} "
f"took {duration:.4f}s with {db_queries} DB queries. "
f"Request ID: {request.request_id}"
)
# 在DEBUG模式下记录数据库查询详情
if settings.DEBUG:
for i, query in enumerate(connection.queries[-db_queries:]):
logger.debug(f"Query {i+1}: {query['sql']} ({query['time']}s)")
4. 高级中间件技巧
4.1 中间件链通信
使用请求对象在中间件之间传递数据:
class FirstMiddleware:
def __init__(self, get_response):
self.get_response = get_response
def __call__(self, request):
# 在请求对象上设置属性
request.first_middleware_data = "Data from first middleware"
# 处理请求
response = self.get_response(request)
# 从响应中读取其他中间件设置的属性
second_data = getattr(response, 'second_middleware_data', None)
print(f"FirstMiddleware received: {second_data}")
return response
class SecondMiddleware:
def __init__(self, get_response):
self.get_response = get_response
def __call__(self, request):
# 读取第一个中间件设置的属性
first_data = getattr(request, 'first_middleware_data', None)
print(f"SecondMiddleware received: {first_data}")
# 处理请求
response = self.get_response(request)
# 在响应对象上设置属性
response.second_middleware_data = "Data from second middleware"
return response
4.2 基于请求类型的条件处理
根据请求类型应用不同逻辑:
class SmartProcessingMiddleware:
def __init__(self, get_response):
self.get_response = get_response
def __call__(self, request):
# 根据请求路径决定处理方式
if request.path.startswith('/api/'):
return self.process_api_request(request)
elif request.path.startswith('/admin/'):
return self.process_admin_request(request)
else:
return self.process_regular_request(request)
def process_api_request(self, request):
# API请求的处理逻辑
start_time = time.time()
response = self.get_response(request)
response['X-API-Time'] = f"{time.time() - start_time:.4f}s"
return response
def process_admin_request(self, request):
# 管理界面请求的处理逻辑
if not request.user.is_staff:
from django.contrib.auth.views import redirect_to_login
return redirect_to_login(request.get_full_path())
return self.get_response(request)
def process_regular_request(self, request):
# 普通请求的处理逻辑
return self.get_response(request)
4.3 实现请求速率限制
创建IP地址速率限制中间件:
import time
from django.core.cache import cache
from django.http import HttpResponse
class RateLimitMiddleware:
def __init__(self, get_response):
self.get_response = get_response
# 配置参数
from django.conf import settings
self.rate_limit = getattr(settings, 'RATE_LIMIT', 100) # 请求/分钟
self.window = getattr(settings, 'RATE_LIMIT_WINDOW', 60) # 窗口期秒数
self.exempt_paths = getattr(settings, 'RATE_LIMIT_EXEMPT', ['/health/', '/ping/'])
def __call__(self, request):
# 对特定路径豁免
if any(request.path.startswith(path) for path in self.exempt_paths):
return self.get_response(request)
# 获取客户端IP
ip = self.get_client_ip(request)
# 检查速率限制
if self.is_rate_limited(ip):
return HttpResponse(
"Rate limit exceeded. Please try again later.",
status=429
)
# 正常处理请求
return self.get_response(request)
def get_client_ip(self, request):
x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
if x_forwarded_for:
ip = x_forwarded_for.split(',')[0].strip()
else:
ip = request.META.get('REMOTE_ADDR')
return ip
def is_rate_limited(self, ip):
cache_key = f"rate_limit:{ip}"
# 获取当前计数
requests = cache.get(cache_key, [])
# 清除窗口期外的请求
current_time = time.time()
window_start = current_time - self.window
requests = [r for r in requests if r > window_start]
# 检查请求数量是否超过限制
if len(requests) >= self.rate_limit:
return True
# 添加新请求时间戳并更新缓存
requests.append(current_time)
cache.set(cache_key, requests, self.window)
return False
4.4 请求/响应修改中间件
实现请求体和响应体修改的中间件:
import json
from django.http import HttpResponse
from django.utils.deprecation import MiddlewareMixin
class RequestBodyModifierMiddleware(MiddlewareMixin):
"""修改传入的JSON请求体"""
def process_request(self, request):
if request.content_type == 'application/json' and request.body:
try:
# 解析JSON请求体
body_data = json.loads(request.body)
# 修改数据
if 'timestamp' not in body_data:
body_data['timestamp'] = int(time.time())
# 将修改后的数据保存回请求对象
request._body = json.dumps(body_data).encode('utf-8')
except json.JSONDecodeError:
# 非有效JSON,不做修改
pass
return None
class ResponseBodyModifierMiddleware:
"""修改输出的JSON响应体"""
def __init__(self, get_response):
self.get_response = get_response
def __call__(self, request):
response = self.get_response(request)
# 仅处理JSON响应
if response.get('Content-Type') == 'application/json':
self.modify_json_response(response)
return response
def modify_json_response(self, response):
try:
# 解析响应内容
content = json.loads(response.content.decode('utf-8'))
# 添加元数据
content['meta'] = {
'version': '1.0',
'timestamp': int(time.time())
}
# 更新响应内容
response.content = json.dumps(content).encode('utf-8')
response['Content-Length'] = str(len(response.content))
except json.JSONDecodeError:
# 非有效JSON,不做修改
pass
5. 处理多语言和本地化
5.1 语言检测与设置中间件
创建自动语言检测中间件:
from django.utils import translation
from django.conf import settings
class LanguageMiddleware:
def __init__(self, get_response):
self.get_response = get_response
self.default_language = settings.LANGUAGE_CODE
self.supported_languages = [lang[0] for lang in settings.LANGUAGES]
def __call__(self, request):
# 获取优先语言
language = self.get_language(request)
# 设置当前请求的语言
translation.activate(language)
request.LANGUAGE_CODE = language
# 处理请求
response = self.get_response(request)
# 为响应设置语言Cookie
response.set_cookie(
settings.LANGUAGE_COOKIE_NAME,
language,
max_age=settings.LANGUAGE_COOKIE_AGE
)
return response
def get_language(self, request):
"""按优先级获取语言"""
# 1. 首先检查URL参数
lang_param = request.GET.get('lang')
if lang_param and lang_param in self.supported_languages:
return lang_param
# 2. 然后检查Cookie
lang_cookie = request.COOKIES.get(settings.LANGUAGE_COOKIE_NAME)
if lang_cookie and lang_cookie in self.supported_languages:
return lang_cookie
# 3. 然后检查Accept-Language头
if request.META.get('HTTP_ACCEPT_LANGUAGE'):
lang_header = request.META['HTTP_ACCEPT_LANGUAGE']
for accepted_lang in lang_header.split(','):
lang_code = accepted_lang.split(';')[0].strip()
# 取语言代码部分 (如 'en-US' → 'en')
primary_lang = lang_code.split('-')[0]
if primary_lang in self.supported_languages:
return primary_lang
# 4. 默认返回设置中的默认语言
return self.default_language
5.2 URL国际化中间件
创建基于URL的语言切换中间件:
from django.utils import translation
from django.conf import settings
from django.urls import is_valid_path, get_resolver
from django.http import HttpResponsePermanentRedirect
class URLLanguageMiddleware:
def __init__(self, get_response):
self.get_response = get_response
self.default_language = settings.LANGUAGE_CODE
self.supported_languages = [lang[0] for lang in settings.LANGUAGES]
# 不需要添加语言前缀的URL路径
self.excluded_paths = getattr(settings, 'LANGUAGE_URL_EXCLUDED_PATHS', [
'/media/', '/static/', '/admin/jsi18n/'
])
def __call__(self, request):
# 从URL获取语言
url_language = self.get_language_from_url(request.path_info)
# 设置检测到的语言
if url_language and url_language in self.supported_languages:
translation.activate(url_language)
request.LANGUAGE_CODE = url_language
else:
# 如果URL中没有语言代码且需要添加
if (
not self.is_excluded_path(request.path_info) and
not request.path_info.startswith(f'/{self.default_language}/')
):
# 重定向到带默认语言的URL
new_path = f'/{self.default_language}{request.path_info}'
if request.META.get('QUERY_STRING'):
new_path += f'?{request.META["QUERY_STRING"]}'
return HttpResponsePermanentRedirect(new_path)
# 处理请求
return self.get_response(request)
def get_language_from_url(self, path):
"""从URL路径中提取语言代码"""
if path.startswith('/'):
path = path[1:]
parts = path.split('/')
if parts and parts[0] in self.supported_languages:
return parts[0]
return None
def is_excluded_path(self, path):
"""检查路径是否在排除列表中"""
return any(path.startswith(excluded) for excluded in self.excluded_paths)
6. 安全相关中间件
6.1 防止XSS攻击的中间件
创建XSS防护中间件:
import re
import html
from django.http import HttpResponseBadRequest
class XSSProtectionMiddleware:
def __init__(self, get_response):
self.get_response = get_response
# 配置参数
from django.conf import settings
self.enable_xss_protection = getattr(settings, 'ENABLE_XSS_PROTECTION', True)
self.exempt_paths = getattr(settings, 'XSS_EXEMPT_PATHS', ['/admin/'])
# XSS攻击特征模式
self.xss_patterns = [
r'<script.*?>',
r'javascript:',
r'onerror=',
r'onload=',
r'eval\(',
r'document\.cookie',
r'alert\('
]
self.compiled_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.xss_patterns]
def __call__(self, request):
if self.enable_xss_protection and not self.is_exempt_path(request.path):
# 检查GET参数
for key, value in request.GET.items():
if self.contains_xss(value):
return self.xss_detected_response(request, 'GET', key, value)
# 检查POST参数
for key, value in request.POST.items():
if isinstance(value, str) and self.contains_xss(value):
return self.xss_detected_response(request, 'POST', key, value)
response = self.get_response(request)
# 添加安全头
response['X-XSS-Protection'] = '1; mode=block'
return response
def contains_xss(self, value):
"""检查值是否包含XSS攻击特征"""
if not isinstance(value, str):
return False
for pattern in self.compiled_patterns:
if pattern.search(value):
return True
return False
def is_exempt_path(self, path):
"""检查路径是否在豁免列表中"""
return any(path.startswith(exempt) for exempt in self.exempt_paths)
def xss_detected_response(self, request, method, key, value):
"""当检测到XSS攻击时的响应"""
import logging
logger = logging.getLogger('security')
# 记录攻击尝试
logger.warning(
f"Potential XSS attack detected: {method} parameter '{key}' "
f"with value '{html.escape(value)}' from {request.META.get('REMOTE_ADDR')}"
)
# 返回错误响应
return HttpResponseBadRequest("Potential security threat detected.")
6.2 内容安全策略中间件
实现动态CSP策略中间件:
from django.conf import settings
from django.utils.crypto import get_random_string
class CSPMiddleware:
def __init__(self, get_response):
self.get_response = get_response
# 从设置加载CSP策略
self.csp_enabled = getattr(settings, 'CSP_ENABLED', True)
self.csp_report_only = getattr(settings, 'CSP_REPORT_ONLY', False)
self.csp_report_uri = getattr(settings, 'CSP_REPORT_URI', None)
# 默认策略
self.default_csp = {
'default-src': ["'self'"],
'script-src': ["'self'"],
'style-src': ["'self'"],
'img-src': ["'self'", 'data:'],
'font-src': ["'self'"],
'connect-src': ["'self'"],
'frame-src': ["'none'"],
'object-src': ["'none'"],
'base-uri': ["'self'"]
}
# 合并用户配置
self.csp_policies = getattr(settings, 'CSP_POLICIES', {})
for key, value in self.csp_policies.items():
if key in self.default_csp:
self.default_csp[key].extend(value)
else:
self.default_csp[key] = value
def __call__(self, request):
response = self.get_response(request)
if self.csp_enabled:
# 生成nonce用于内联脚本
nonce = get_random_string(16)
request.csp_nonce = nonce
# 构建CSP头
csp = self.build_csp_header(nonce)
# 设置CSP头
header_name = 'Content-Security-Policy-Report-Only' if self.csp_report_only else 'Content-Security-Policy'
response[header_name] = csp
return response
def build_csp_header(self, nonce):
"""构建CSP头值"""
policies = []
# 添加nonce到script-src和style-src
if 'script-src' in self.default_csp:
self.default_csp['script-src'].append(f"'nonce-{nonce}'")
if 'style-src' in self.default_csp:
self.default_csp['style-src'].append(f"'nonce-{nonce}'")
# 构建策略字符串
for directive, sources in self.default_csp.items():
policies.append(f"{directive} {' '.join(sources)}")
# 添加报告URI
if self.csp_report_uri:
policies.append(f"report-uri {self.csp_report_uri}")
return "; ".join(policies)
6.3 点击劫持保护中间件
class ClickjackingProtectionMiddleware:
def __init__(self, get_response):
self.get_response = get_response
# 配置参数
from django.conf import settings
self.exempt_paths = getattr(settings, 'CLICKJACKING_EXEMPT_PATHS', [])
self.frame_options = getattr(settings, 'X_FRAME_OPTIONS', 'DENY') # 'DENY', 'SAMEORIGIN' 或 'ALLOW-FROM uri'
def __call__(self, request):
response = self.get_response(request)
# 检查路径是否在豁免列表中
if not any(request.path.startswith(path) for path in self.exempt_paths):
# 添加X-Frame-Options头
response['X-Frame-Options'] = self.frame_options
return response
7. 缓存中间件实现
7.1 基本页面缓存中间件
创建页面缓存中间件:
from django.core.cache import cache
from django.utils.cache import get_cache_key, learn_cache_key
from django.utils.deprecation import MiddlewareMixin
from django.conf import settings
class PageCacheMiddleware(MiddlewareMixin):
def process_request(self, request):
"""尝试从缓存获取响应"""
# 只缓存GET请求
if request.method != 'GET':
return None
# 不缓存已认证用户的请求
if request.user.is_authenticated:
return None
# 生成缓存键
cache_key = self.get_cache_key(request)
if cache_key is None:
return None
# 尝试从缓存读取
response = cache.get(cache_key)
if response is None:
return None
return response
def process_response(self, request, response):
"""缓存成功的GET响应"""
# 只缓存GET请求和成功的响应
if (
request.method != 'GET' or
response.status_code != 200 or
request.user.is_authenticated or
not self.should_cache_response(request, response)
):
return response
# 生成并设置缓存键
cache_key = learn_cache_key(request, response)
if hasattr(response, 'render') and callable(response.render):
# 处理模板响应
response.add_post_render_callback(
lambda rendered_response: cache.set(
cache_key, rendered_response, self.get_cache_timeout(request)
)
)
else:
# 处理普通响应
cache.set(cache_key, response, self.get_cache_timeout(request))
return response
def get_cache_key(self, request):
"""生成缓存键"""
return get_cache_key(request)
def get_cache_timeout(self, request):
"""获取缓存超时时间"""
# 不同URL可以有不同的缓存超时时间
path = request.path
if path.startswith('/blog/'):
return 60 * 60 # 博客页面缓存1小时
elif path.startswith('/product/'):
return 60 * 15 # 产品页面缓存15分钟
else:
return getattr(settings, 'PAGE_CACHE_SECONDS', 60 * 5) # 默认5分钟
def should_cache_response(self, request, response):
"""决定是否应该缓存响应"""
# 不缓存带有特定Cookie的请求
if request.COOKIES.get('no_cache'):
return False
# 不缓存某些路径
for prefix in getattr(settings, 'CACHE_EXCLUDED_PATHS', []):
if request.path.startswith(prefix):
return False
# 不缓存包含特定头的响应
if response.get('Cache-Control') == 'no-cache':
return False
return True
7.2 分层缓存中间件
实现多级缓存策略:
from django.core.cache import cache, caches
from django.utils.cache import get_cache_key
from django.utils.deprecation import MiddlewareMixin
import hashlib
class TieredCacheMiddleware(MiddlewareMixin):
def __init__(self, get_response=None):
super().__init__(get_response)
# 配置不同的缓存后端
self.memory_cache = caches['memory'] # 内存缓存
self.disk_cache = caches['disk'] # 磁盘缓存
self.redis_cache = cache # Redis缓存
# 缓存时间配置(秒)
self.memory_timeout = 60 # 内存缓存1分钟
self.disk_timeout = 60 * 15 # 磁盘缓存15分钟
self.redis_timeout = 60 * 60 * 4 # Redis缓存4小时
def process_request(self, request):
"""尝试从不同层级的缓存获取响应"""
if request.method != 'GET':
return None
cache_key = self.get_cache_key(request)
if not cache_key:
return None
# 首先尝试内存缓存(最快)
response = self.memory_cache.get(cache_key)
if response:
return response
# 然后尝试磁盘缓存
response = self.disk_cache.get(cache_key)
if response:
# 填充内存缓存
self.memory_cache.set(cache_key, response, self.memory_timeout)
return response
# 最后尝试Redis缓存
response = self.redis_cache.get(cache_key)
if response:
# 填充内存和磁盘缓存
self.memory_cache.set(cache_key, response, self.memory_timeout)
self.disk_cache.set(cache_key, response, self.disk_timeout)
return response
return None
def process_response(self, request, response):
"""将响应保存到多个缓存层"""
if (
request.method != 'GET' or
response.status_code != 200 or
not self.should_cache_response(request, response)
):
return response
cache_key = self.get_cache_key(request)
if not cache_key:
return response
# 为不同类型的响应设置不同的缓存策略
content_type = response.get('Content-Type', '')
is_html = 'text/html' in content_type
is_api = 'application/json' in content_type
# 根据响应类型设置缓存
if is_api:
# API响应:使用短期缓存
self.memory_cache.set(cache_key, response, 30) # 30秒
self.redis_cache.set(cache_key, response, 60) # 1分钟
elif is_html:
# HTML响应:使用中期缓存
self.memory_cache.set(cache_key, response, self.memory_timeout)
self.disk_cache.set(cache_key, response, self.disk_timeout)
# 静态页面可以缓存更长时间
if 'static-page' in request.GET:
self.redis_cache.set(cache_key, response, self.redis_timeout)
else:
# 其他类型响应:常规缓存策略
self.memory_cache.set(cache_key, response, self.memory_timeout)
self.disk_cache.set(cache_key, response, self.disk_timeout)
self.redis_cache.set(cache_key, response, self.redis_timeout)
return response
def get_cache_key(self, request):
"""生成缓存键"""
key = get_cache_key(request)
if not key:
# 备选方法:生成自定义缓存键
key_parts = [
request.path,
request.META.get('QUERY_STRING', ''),
request.COOKIES.get('session_variant', ''), # 用于A/B测试
]
key = hashlib.md5(''.join(key_parts).encode()).hexdigest()
return key
def should_cache_response(self, request, response):
"""决定是否应该缓存响应"""
# 实现缓存决策逻辑...
return True
8. 调试与监控中间件
8.1 SQL查询监控中间件
创建SQL查询监控中间件:
import time
import json
import logging
from django.db import connection
from django.conf import settings
logger = logging.getLogger('sql_performance')
class SQLMonitoringMiddleware:
def __init__(self, get_response):
self.get_response = get_response
self.slow_query_threshold = getattr(settings, 'SLOW_QUERY_THRESHOLD', 0.1) # 秒
def __call__(self, request):
# 只在DEBUG模式或明确启用时启用SQL监控
if not (settings.DEBUG or getattr(settings, 'SQL_MONITORING_ENABLED', False)):
return self.get_response(request)
# 记录起始查询数
start_queries = len(connection.queries)
start_time = time.time()
# 处理请求
response = self.get_response(request)
# 计算查询统计信息
end_time = time.time()
duration = end_time - start_time
# 获取请求期间执行的查询
queries = connection.queries[start_queries:]
num_queries = len(queries)
if num_queries > 0:
# 计算查询总时间
query_time = sum(float(q['time']) for q in queries)
# 找出慢查询
slow_queries = [q for q in queries if float(q['time']) > self.slow_query_threshold]
# 记录统计信息
logger.info(
f"Request: {request.method} {request.path} | "
f"Duration: {duration:.4f}s | "
f"Queries: {num_queries} | "
f"Query Time: {query_time:.4f}s | "
f"Slow Queries: {len(slow_queries)}"
)
# 记录慢查询详情
for i, query in enumerate(slow_queries):
logger.warning(
f"Slow Query #{i+1} ({query['time']}s): {query['sql']}"
)
# 在响应中添加SQL统计信息(仅在DEBUG模式)
if settings.DEBUG and request.headers.get('X-Requested-With') == 'XMLHttpRequest':
# 对于AJAX请求,返回JSON响应中添加调试信息
if 'application/json' in response.get('Content-Type', ''):
try:
data = json.loads(response.content.decode('utf-8'))
# 添加调试信息
if isinstance(data, dict):
data['__debug'] = {
'duration': duration,
'queries': num_queries,
'query_time': query_time,
'slow_queries': len(slow_queries)
}
# 更新响应内容
response.content = json.dumps(data).encode('utf-8')
response['Content-Length'] = str(len(response.content))
except (json.JSONDecodeError, UnicodeDecodeError
## 4. 中间件应用场景与实战解析
### 4.1 性能监控中间件
监控请求执行时间是优化应用性能的第一步。这个中间件可以帮助你识别慢请求:
```python
import time
import logging
from django.conf import settings
logger = logging.getLogger('performance')
class PerformanceMonitorMiddleware:
def __init__(self, get_response):
self.get_response = get_response
# 设置阈值,超过该时间(ms)的请求将被记录
self.threshold = getattr(settings, 'SLOW_REQUEST_THRESHOLD', 500)
def __call__(self, request):
start_time = time.time()
response = self.get_response(request)
duration = (time.time() - start_time) * 1000 # 转换为毫秒
if duration > self.threshold:
logger.warning(
f'慢请求: {request.method} {request.path} 耗时 {duration:.2f}ms'
)
return response
在生产环境中,你可以配合ELK或Prometheus等监控系统,实现性能数据可视化和告警。
4.2 IP限流中间件
防止恶意请求和DOS攻击的简单限流实现:
from django.core.cache import cache
from django.http import HttpResponse
import time
class RateLimitMiddleware:
def __init__(self, get_response):
self.get_response = get_response
self.rate_limit = 60 # 每分钟最大请求数
self.window = 60 # 时间窗口(秒)
def __call__(self, request):
# 获取客户端IP
ip = self.get_client_ip(request)
cache_key = f'rate_limit:{ip}'
# 获取当前请求计数
requests = cache.get(cache_key, [])
now = time.time()
# 仅保留时间窗口内的请求记录
requests = [t for t in requests if now - t < self.window]
# 检查是否超过限制
if len(requests) >= self.rate_limit:
return HttpResponse('请求频率过高,请稍后再试', status=429)
# 添加当前请求时间
requests.append(now)
cache.set(cache_key, requests, self.window)
# 继续处理请求
return self.get_response(request)
def get_client_ip(self, request):
x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
if x_forwarded_for:
ip = x_forwarded_for.split(',')[0]
else:
ip = request.META.get('REMOTE_ADDR')
return ip
4.3 跨域资源共享(CORS)中间件
实现自定义CORS策略(而不是使用django-cors-headers):
class CorsMiddleware:
def __init__(self, get_response):
self.get_response = get_response
def __call__(self, request):
response = self.get_response(request)
# 添加CORS头信息
response["Access-Control-Allow-Origin"] = "https://trusted-site.com"
response["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS"
response["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
response["Access-Control-Allow-Credentials"] = "true"
return response
def process_view(self, request, view_func, view_args, view_kwargs):
# 处理预检请求
if request.method == "OPTIONS":
response = HttpResponse()
response["Access-Control-Allow-Origin"] = "https://trusted-site.com"
response["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS"
response["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
response["Access-Control-Allow-Credentials"] = "true"
response["Access-Control-Max-Age"] = "86400" # 24小时
return response
return None
5. 中间件性能与优化
中间件在每个请求中都会执行,因此它们的性能直接影响整个应用的响应时间。
5.1 性能优化策略
- 减少处理逻辑:中间件应专注于其核心功能,避免复杂计算
- 使用缓存:频繁访问的数据应当缓存
- 延迟导入:在
__init__
方法中导入依赖,而不是模块级别 - 条件执行:通过URL路径或请求类型过滤,只在必要时处理
class ConditionalMiddleware:
def __init__(self, get_response):
self.get_response = get_response
# 延迟导入依赖
import re
self.re = re
# 定义白名单
self.url_patterns = [
self.re.compile(r'^/api/'),
self.re.compile(r'^/admin/'),
]
def __call__(self, request):
# 条件执行
if self.should_process(request):
# 处理逻辑
pass
return self.get_response(request)
def should_process(self, request):
# 检查是否匹配白名单路径
path = request.path
return any(pattern.match(path) for pattern in self.url_patterns)
6. 中间件测试策略
中间件测试是确保它们正常工作的关键。以下是测试中间件的最佳实践:
6.1 单元测试
from django.test import TestCase, RequestFactory
from django.http import HttpResponse
from myapp.middleware import PerformanceMonitorMiddleware
class PerformanceMiddlewareTest(TestCase):
def setUp(self):
self.factory = RequestFactory()
# 创建一个模拟的get_response函数
def get_response(request):
response = HttpResponse("测试响应")
return response
self.middleware = PerformanceMonitorMiddleware(get_response)
def test_slow_request_is_logged(self):
# 创建请求
request = self.factory.get('/test/')
# 模拟慢响应
import time
original_time = time.time
try:
# 模拟时间流逝
mock_time_values = [100.0, 101.0] # 模拟1秒(1000ms)的请求
time.time = lambda: mock_time_values.pop(0)
with self.assertLogs('performance', level='WARNING') as cm:
response = self.middleware(request)
self.assertIn('慢请求', cm.output[0])
finally:
# 恢复原始函数
time.time = original_time
6.2 集成测试
from django.test import TestCase, Client
from django.urls import reverse
class RateLimitMiddlewareIntegrationTest(TestCase):
def setUp(self):
self.client = Client()
def test_rate_limit_exceeded(self):
url = reverse('api-endpoint')
# 发送超过限制的请求
for _ in range(61): # 假设限制为60
response = self.client.get(url)
# 验证最后一个请求被限制
self.assertEqual(response.status_code, 429)
self.assertIn('请求频率过高', response.content.decode())
7. 与内置中间件协同工作
Django有许多内置中间件,如何与它们协同工作是中间件开发的重要方面。
7.1 中间件顺序
在settings.py
中,中间件按照定义顺序执行:
MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'myproject.middleware.custom_middleware.CustomMiddleware', # 自定义中间件
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
]
7.2 与内置中间件交互
class CustomAuthMiddleware:
def __init__(self, get_response):
self.get_response = get_response
def __call__(self, request):
# 需要在SessionMiddleware和AuthenticationMiddleware之后
# 才能访问request.user和request.session
if not request.user.is_authenticated:
# 自定义认证逻辑
pass
return self.get_response(request)
8. 总结与最佳实践
开发Django中间件时,牢记以下最佳实践:
- 职责单一:每个中间件只负责一个明确的功能
- 性能优先:中间件执行频率高,必须高效
- 错误处理:妥善处理异常,不影响请求流程
- 测试覆盖:编写全面的测试确保正常工作
- 文档完善:为团队成员提供清晰的使用指南
- 配置灵活:通过settings提供可配置选项
- 兼容性考虑:特别是与内置中间件的交互
- 版本迁移:确保在Django版本升级时平滑过渡
Django中间件是构建健壮、可维护Web应用的强大工具。通过本文介绍的模式和实践,你可以充分发挥中间件的潜力,构建出更加安全、高效、易于维护的Django应用。
在下一篇文章中,我们将深入探讨Django的模板系统,带你掌握如何构建灵活且高效的前端展示层。