06【核心精通】Django中间件开发与应用:10个高级技巧打造强大Web应用

【核心精通】Django中间件开发与应用:10个高级技巧打造强大Web应用

前言:中间件如何成为Django应用的隐形强者?

在Django的架构中,中间件(Middleware)扮演着至关重要的角色,它们像无形的守卫一样拦截请求和响应,可以在不修改现有视图代码的情况下为整个应用添加功能。根据统计,超过90%的Django项目都使用了至少3个自定义中间件来解决跨领域关注点,如安全、性能监控和用户跟踪等。然而,大多数开发者对中间件的了解仅限于使用内置组件,而未充分发挥其潜力。本文将深入探讨Django中间件的开发与应用,揭示如何通过10个高级技巧创建强大且高效的中间件组件,为你的Django项目增添新的维度。

1. 中间件基础概念

1.1 什么是Django中间件?

中间件是Django请求/响应处理流程中的钩子机制,能够处理全局性的问题,如安全检查、会话管理、用户认证等。它们被配置在应用的请求处理管道中,按顺序处理每个请求和响应。

# settings.py中的中间件配置
MIDDLEWARE = [
    'django.middleware.security.SecurityMiddleware',
    'django.contrib.sessions.middleware.SessionMiddleware',
    'django.middleware.common.CommonMiddleware',
    'django.middleware.csrf.CsrfViewMiddleware',
    'django.contrib.auth.middleware.AuthenticationMiddleware',
    'django.contrib.messages.middleware.MessageMiddleware',
    'django.middleware.clickjacking.XFrameOptionsMiddleware',
    # 自定义中间件
    'myapp.middleware.custom_middleware.CustomMiddleware',
]

1.2 请求/响应流程与中间件执行顺序

中间件在请求/响应流程中的执行顺序非常重要:

  1. 请求从顶到底通过中间件栈
  2. 到达视图函数
  3. 响应从底到顶返回通过中间件栈
请求处理流程:
Browser → [Security] → [Session] → [CSRF] → [Auth] → ... → View
响应处理流程:
View → ... → [Auth] → [CSRF] → [Session] → [Security] → Browser

中间件顺序会影响功能,例如SessionMiddleware必须在AuthenticationMiddleware之前,因为后者依赖于前者创建的会话。

1.3 中间件类型与钩子方法

Django中间件类可以定义以下方法:

class SimpleMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
        # 一次性配置和初始化

    def __call__(self, request):
        # 处理视图调用前的请求
        
        response = self.get_response(request)
        
        # 处理视图调用后的响应
        
        return response
    
    def process_view(self, request, view_func, view_args, view_kwargs):
        # 在视图函数调用前执行
        return None
    
    def process_exception(self, request, exception):
        # 视图引发异常时调用
        return None
    
    def process_template_response(self, request, response):
        # 对模板响应进行处理
        return response

2. 创建第一个中间件

2.1 中间件基本结构

让我们创建一个简单的中间件,用于记录请求处理时间:

# myapp/middleware/timing.py
import time
import logging

logger = logging.getLogger(__name__)

class RequestTimingMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
    
    def __call__(self, request):
        # 记录请求开始时间
        start_time = time.time()
        
        # 调用视图
        response = self.get_response(request)
        
        # 计算请求处理时间
        duration = time.time() - start_time
        logger.info(f"Request to {request.path} took {duration:.4f}s")
        
        # 添加处理时间到响应头
        response['X-Request-Time'] = str(duration)
        
        return response

settings.py中启用中间件:

MIDDLEWARE = [
    # ...其他中间件
    'myapp.middleware.timing.RequestTimingMiddleware',
]

2.2 同步和异步中间件

从Django 3.1开始,中间件支持异步处理:

# 同步中间件
class SyncMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
    
    def __call__(self, request):
        # 同步处理
        return self.get_response(request)

# 异步中间件
class AsyncMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
    
    async def __call__(self, request):
        # 异步处理请求
        response = await self.get_response(request)
        return response

混合环境中的兼容性处理:

import asyncio

class CompatibleMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
        
        # 确定get_response是同步或异步
        self.is_async = asyncio.iscoroutinefunction(get_response)
    
    async def __call__(self, request):
        if self.is_async:
            # 异步处理
            return await self.get_response(request)
        else:
            # 同步处理在异步环境下
            return await asyncio.to_thread(self.get_response, request)

2.3 中间件的启用与禁用

根据条件动态启用/禁用中间件:

class ConditionalMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
        
        # 从配置决定是否启用
        from django.conf import settings
        self.enabled = getattr(settings, 'ENABLE_CUSTOM_MIDDLEWARE', True)
    
    def __call__(self, request):
        if not self.enabled:
            # 中间件被禁用,直接传递请求
            return self.get_response(request)
        
        # 中间件启用,执行额外逻辑
        # ...
        
        response = self.get_response(request)
        
        # 处理响应
        # ...
        
        return response

3. 常见中间件应用场景

3.1 用户认证与授权中间件

实现基于JWT的认证中间件:

# myapp/middleware/jwt_auth.py
import jwt
from django.contrib.auth.models import AnonymousUser
from django.conf import settings
from django.utils.functional import SimpleLazyObject
from myapp.models import User

def get_user_from_token(request):
    # 从请求头获取令牌
    auth_header = request.headers.get('Authorization', '')
    if not auth_header.startswith('Bearer '):
        return AnonymousUser()
    
    token = auth_header.split(' ')[1]
    
    try:
        # 验证令牌
        payload = jwt.decode(
            token, 
            settings.JWT_SECRET_KEY,
            algorithms=['HS256']
        )
        
        # 获取用户
        user_id = payload.get('user_id')
        if user_id:
            return User.objects.get(id=user_id)
        
    except (jwt.InvalidTokenError, User.DoesNotExist):
        pass
    
    return AnonymousUser()

class JWTAuthMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
    
    def __call__(self, request):
        # 使用SimpleLazyObject延迟加载用户
        # 只有在需要时才验证令牌和查询数据库
        request.user = SimpleLazyObject(lambda: get_user_from_token(request))
        
        return self.get_response(request)

3.2 请求/响应修改中间件

添加安全头的中间件:

class SecurityHeadersMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
    
    def __call__(self, request):
        response = self.get_response(request)
        
        # 添加安全相关的HTTP头
        response['Content-Security-Policy'] = "default-src 'self'"
        response['X-Content-Type-Options'] = 'nosniff'
        response['X-Frame-Options'] = 'DENY'
        response['X-XSS-Protection'] = '1; mode=block'
        response['Referrer-Policy'] = 'strict-origin-when-cross-origin'
        
        # 仅在HTTPS时设置
        if request.is_secure():
            age = 60 * 60 * 24 * 365  # 1年
            response['Strict-Transport-Security'] = f'max-age={age}; includeSubDomains'
        
        return response

3.3 异常处理中间件

创建全局异常处理中间件:

import json
import traceback
import logging
from django.http import JsonResponse, HttpResponse
from django.conf import settings

logger = logging.getLogger(__name__)

class GlobalExceptionMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
    
    def __call__(self, request):
        try:
            return self.get_response(request)
        except Exception as exc:
            return self.handle_exception(request, exc)
    
    def handle_exception(self, request, exception):
        # 记录异常
        logger.error(
            f"Unhandled exception: {str(exception)}\n"
            f"URL: {request.path}\n"
            f"{traceback.format_exc()}"
        )
        
        # 确定是否为API请求
        is_api_request = request.path.startswith('/api/') or \
                         request.headers.get('Accept') == 'application/json'
        
        # API请求返回JSON响应
        if is_api_request:
            return self.handle_api_exception(request, exception)
        
        # 常规请求渲染错误页面
        return self.handle_html_exception(request, exception)
    
    def handle_api_exception(self, request, exception):
        status_code = getattr(exception, 'status_code', 500)
        
        data = {
            'error': True,
            'message': str(exception),
            'detail': str(exception) if settings.DEBUG else None,
            'type': exception.__class__.__name__
        }
        
        return JsonResponse(data, status=status_code)
    
    def handle_html_exception(self, request, exception):
        if settings.DEBUG:
            # 在DEBUG模式下重新抛出异常,显示详细信息
            raise
        
        # 生产环境返回友好的错误页面
        from django.template.response import TemplateResponse
        return TemplateResponse(
            request,
            'errors/500.html',
            context={'error_message': str(exception)},
            status=500
        )

3.4 性能监控中间件

创建性能监控中间件:

import time
import uuid
from django.db import connection
from django.conf import settings

class PerformanceMonitoringMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
    
    def __call__(self, request):
        # 生成请求ID,用于跟踪整个请求过程
        request_id = str(uuid.uuid4())
        request.request_id = request_id
        
        # 记录开始时间和数据库查询次数
        start_time = time.time()
        start_queries = len(connection.queries)
        
        # 处理请求
        response = self.get_response(request)
        
        # 计算处理时间和数据库查询统计
        duration = time.time() - start_time
        db_queries = len(connection.queries) - start_queries
        
        # 将统计信息添加到响应头
        response['X-Request-ID'] = request_id
        response['X-Request-Time'] = f"{duration:.4f}s"
        response['X-DB-Queries'] = str(db_queries)
        
        # 记录慢请求
        threshold = getattr(settings, 'SLOW_REQUEST_THRESHOLD', 1.0)
        if duration > threshold:
            self.log_slow_request(request, duration, db_queries)
        
        return response
    
    def log_slow_request(self, request, duration, db_queries):
        import logging
        logger = logging.getLogger('performance')
        
        logger.warning(
            f"Slow request detected: {request.method} {request.path} "
            f"took {duration:.4f}s with {db_queries} DB queries. "
            f"Request ID: {request.request_id}"
        )
        
        # 在DEBUG模式下记录数据库查询详情
        if settings.DEBUG:
            for i, query in enumerate(connection.queries[-db_queries:]):
                logger.debug(f"Query {i+1}: {query['sql']} ({query['time']}s)")

4. 高级中间件技巧

4.1 中间件链通信

使用请求对象在中间件之间传递数据:

class FirstMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
    
    def __call__(self, request):
        # 在请求对象上设置属性
        request.first_middleware_data = "Data from first middleware"
        
        # 处理请求
        response = self.get_response(request)
        
        # 从响应中读取其他中间件设置的属性
        second_data = getattr(response, 'second_middleware_data', None)
        print(f"FirstMiddleware received: {second_data}")
        
        return response

class SecondMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
    
    def __call__(self, request):
        # 读取第一个中间件设置的属性
        first_data = getattr(request, 'first_middleware_data', None)
        print(f"SecondMiddleware received: {first_data}")
        
        # 处理请求
        response = self.get_response(request)
        
        # 在响应对象上设置属性
        response.second_middleware_data = "Data from second middleware"
        
        return response

4.2 基于请求类型的条件处理

根据请求类型应用不同逻辑:

class SmartProcessingMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
    
    def __call__(self, request):
        # 根据请求路径决定处理方式
        if request.path.startswith('/api/'):
            return self.process_api_request(request)
        elif request.path.startswith('/admin/'):
            return self.process_admin_request(request)
        else:
            return self.process_regular_request(request)
    
    def process_api_request(self, request):
        # API请求的处理逻辑
        start_time = time.time()
        response = self.get_response(request)
        response['X-API-Time'] = f"{time.time() - start_time:.4f}s"
        return response
    
    def process_admin_request(self, request):
        # 管理界面请求的处理逻辑
        if not request.user.is_staff:
            from django.contrib.auth.views import redirect_to_login
            return redirect_to_login(request.get_full_path())
        
        return self.get_response(request)
    
    def process_regular_request(self, request):
        # 普通请求的处理逻辑
        return self.get_response(request)

4.3 实现请求速率限制

创建IP地址速率限制中间件:

import time
from django.core.cache import cache
from django.http import HttpResponse

class RateLimitMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
        
        # 配置参数
        from django.conf import settings
        self.rate_limit = getattr(settings, 'RATE_LIMIT', 100)  # 请求/分钟
        self.window = getattr(settings, 'RATE_LIMIT_WINDOW', 60)  # 窗口期秒数
        self.exempt_paths = getattr(settings, 'RATE_LIMIT_EXEMPT', ['/health/', '/ping/'])
    
    def __call__(self, request):
        # 对特定路径豁免
        if any(request.path.startswith(path) for path in self.exempt_paths):
            return self.get_response(request)
        
        # 获取客户端IP
        ip = self.get_client_ip(request)
        
        # 检查速率限制
        if self.is_rate_limited(ip):
            return HttpResponse(
                "Rate limit exceeded. Please try again later.",
                status=429
            )
        
        # 正常处理请求
        return self.get_response(request)
    
    def get_client_ip(self, request):
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if x_forwarded_for:
            ip = x_forwarded_for.split(',')[0].strip()
        else:
            ip = request.META.get('REMOTE_ADDR')
        return ip
    
    def is_rate_limited(self, ip):
        cache_key = f"rate_limit:{ip}"
        
        # 获取当前计数
        requests = cache.get(cache_key, [])
        
        # 清除窗口期外的请求
        current_time = time.time()
        window_start = current_time - self.window
        requests = [r for r in requests if r > window_start]
        
        # 检查请求数量是否超过限制
        if len(requests) >= self.rate_limit:
            return True
        
        # 添加新请求时间戳并更新缓存
        requests.append(current_time)
        cache.set(cache_key, requests, self.window)
        
        return False

4.4 请求/响应修改中间件

实现请求体和响应体修改的中间件:

import json
from django.http import HttpResponse
from django.utils.deprecation import MiddlewareMixin

class RequestBodyModifierMiddleware(MiddlewareMixin):
    """修改传入的JSON请求体"""
    
    def process_request(self, request):
        if request.content_type == 'application/json' and request.body:
            try:
                # 解析JSON请求体
                body_data = json.loads(request.body)
                
                # 修改数据
                if 'timestamp' not in body_data:
                    body_data['timestamp'] = int(time.time())
                
                # 将修改后的数据保存回请求对象
                request._body = json.dumps(body_data).encode('utf-8')
                
            except json.JSONDecodeError:
                # 非有效JSON,不做修改
                pass
        
        return None

class ResponseBodyModifierMiddleware:
    """修改输出的JSON响应体"""
    
    def __init__(self, get_response):
        self.get_response = get_response
    
    def __call__(self, request):
        response = self.get_response(request)
        
        # 仅处理JSON响应
        if response.get('Content-Type') == 'application/json':
            self.modify_json_response(response)
        
        return response
    
    def modify_json_response(self, response):
        try:
            # 解析响应内容
            content = json.loads(response.content.decode('utf-8'))
            
            # 添加元数据
            content['meta'] = {
                'version': '1.0',
                'timestamp': int(time.time())
            }
            
            # 更新响应内容
            response.content = json.dumps(content).encode('utf-8')
            response['Content-Length'] = str(len(response.content))
            
        except json.JSONDecodeError:
            # 非有效JSON,不做修改
            pass

5. 处理多语言和本地化

5.1 语言检测与设置中间件

创建自动语言检测中间件:

from django.utils import translation
from django.conf import settings

class LanguageMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
        self.default_language = settings.LANGUAGE_CODE
        self.supported_languages = [lang[0] for lang in settings.LANGUAGES]
    
    def __call__(self, request):
        # 获取优先语言
        language = self.get_language(request)
        
        # 设置当前请求的语言
        translation.activate(language)
        request.LANGUAGE_CODE = language
        
        # 处理请求
        response = self.get_response(request)
        
        # 为响应设置语言Cookie
        response.set_cookie(
            settings.LANGUAGE_COOKIE_NAME, 
            language,
            max_age=settings.LANGUAGE_COOKIE_AGE
        )
        
        return response
    
    def get_language(self, request):
        """按优先级获取语言"""
        
        # 1. 首先检查URL参数
        lang_param = request.GET.get('lang')
        if lang_param and lang_param in self.supported_languages:
            return lang_param
        
        # 2. 然后检查Cookie
        lang_cookie = request.COOKIES.get(settings.LANGUAGE_COOKIE_NAME)
        if lang_cookie and lang_cookie in self.supported_languages:
            return lang_cookie
        
        # 3. 然后检查Accept-Language头
        if request.META.get('HTTP_ACCEPT_LANGUAGE'):
            lang_header = request.META['HTTP_ACCEPT_LANGUAGE']
            for accepted_lang in lang_header.split(','):
                lang_code = accepted_lang.split(';')[0].strip()
                # 取语言代码部分 (如 'en-US' → 'en')
                primary_lang = lang_code.split('-')[0]
                
                if primary_lang in self.supported_languages:
                    return primary_lang
        
        # 4. 默认返回设置中的默认语言
        return self.default_language

5.2 URL国际化中间件

创建基于URL的语言切换中间件:

from django.utils import translation
from django.conf import settings
from django.urls import is_valid_path, get_resolver
from django.http import HttpResponsePermanentRedirect

class URLLanguageMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
        self.default_language = settings.LANGUAGE_CODE
        self.supported_languages = [lang[0] for lang in settings.LANGUAGES]
        
        # 不需要添加语言前缀的URL路径
        self.excluded_paths = getattr(settings, 'LANGUAGE_URL_EXCLUDED_PATHS', [
            '/media/', '/static/', '/admin/jsi18n/'
        ])
    
    def __call__(self, request):
        # 从URL获取语言
        url_language = self.get_language_from_url(request.path_info)
        
        # 设置检测到的语言
        if url_language and url_language in self.supported_languages:
            translation.activate(url_language)
            request.LANGUAGE_CODE = url_language
        else:
            # 如果URL中没有语言代码且需要添加
            if (
                not self.is_excluded_path(request.path_info) and 
                not request.path_info.startswith(f'/{self.default_language}/')
            ):
                # 重定向到带默认语言的URL
                new_path = f'/{self.default_language}{request.path_info}'
                if request.META.get('QUERY_STRING'):
                    new_path += f'?{request.META["QUERY_STRING"]}'
                
                return HttpResponsePermanentRedirect(new_path)
        
        # 处理请求
        return self.get_response(request)
    
    def get_language_from_url(self, path):
        """从URL路径中提取语言代码"""
        if path.startswith('/'):
            path = path[1:]
        
        parts = path.split('/')
        if parts and parts[0] in self.supported_languages:
            return parts[0]
        
        return None
    
    def is_excluded_path(self, path):
        """检查路径是否在排除列表中"""
        return any(path.startswith(excluded) for excluded in self.excluded_paths)

6. 安全相关中间件

6.1 防止XSS攻击的中间件

创建XSS防护中间件:

import re
import html
from django.http import HttpResponseBadRequest

class XSSProtectionMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
        
        # 配置参数
        from django.conf import settings
        self.enable_xss_protection = getattr(settings, 'ENABLE_XSS_PROTECTION', True)
        self.exempt_paths = getattr(settings, 'XSS_EXEMPT_PATHS', ['/admin/'])
        
        # XSS攻击特征模式
        self.xss_patterns = [
            r'<script.*?>',
            r'javascript:',
            r'onerror=',
            r'onload=',
            r'eval\(',
            r'document\.cookie',
            r'alert\('
        ]
        self.compiled_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in self.xss_patterns]
    
    def __call__(self, request):
        if self.enable_xss_protection and not self.is_exempt_path(request.path):
            # 检查GET参数
            for key, value in request.GET.items():
                if self.contains_xss(value):
                    return self.xss_detected_response(request, 'GET', key, value)
            
            # 检查POST参数
            for key, value in request.POST.items():
                if isinstance(value, str) and self.contains_xss(value):
                    return self.xss_detected_response(request, 'POST', key, value)
                    
        response = self.get_response(request)
        
        # 添加安全头
        response['X-XSS-Protection'] = '1; mode=block'
        
        return response
    
    def contains_xss(self, value):
        """检查值是否包含XSS攻击特征"""
        if not isinstance(value, str):
            return False
            
        for pattern in self.compiled_patterns:
            if pattern.search(value):
                return True
        
        return False
    
    def is_exempt_path(self, path):
        """检查路径是否在豁免列表中"""
        return any(path.startswith(exempt) for exempt in self.exempt_paths)
    
    def xss_detected_response(self, request, method, key, value):
        """当检测到XSS攻击时的响应"""
        import logging
        logger = logging.getLogger('security')
        
        # 记录攻击尝试
        logger.warning(
            f"Potential XSS attack detected: {method} parameter '{key}' "
            f"with value '{html.escape(value)}' from {request.META.get('REMOTE_ADDR')}"
        )
        
        # 返回错误响应
        return HttpResponseBadRequest("Potential security threat detected.")

6.2 内容安全策略中间件

实现动态CSP策略中间件:

from django.conf import settings
from django.utils.crypto import get_random_string

class CSPMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
        
        # 从设置加载CSP策略
        self.csp_enabled = getattr(settings, 'CSP_ENABLED', True)
        self.csp_report_only = getattr(settings, 'CSP_REPORT_ONLY', False)
        self.csp_report_uri = getattr(settings, 'CSP_REPORT_URI', None)
        
        # 默认策略
        self.default_csp = {
            'default-src': ["'self'"],
            'script-src': ["'self'"],
            'style-src': ["'self'"],
            'img-src': ["'self'", 'data:'],
            'font-src': ["'self'"],
            'connect-src': ["'self'"],
            'frame-src': ["'none'"],
            'object-src': ["'none'"],
            'base-uri': ["'self'"]
        }
        
        # 合并用户配置
        self.csp_policies = getattr(settings, 'CSP_POLICIES', {})
        for key, value in self.csp_policies.items():
            if key in self.default_csp:
                self.default_csp[key].extend(value)
            else:
                self.default_csp[key] = value
    
    def __call__(self, request):
        response = self.get_response(request)
        
        if self.csp_enabled:
            # 生成nonce用于内联脚本
            nonce = get_random_string(16)
            request.csp_nonce = nonce
            
            # 构建CSP头
            csp = self.build_csp_header(nonce)
            
            # 设置CSP头
            header_name = 'Content-Security-Policy-Report-Only' if self.csp_report_only else 'Content-Security-Policy'
            response[header_name] = csp
        
        return response
    
    def build_csp_header(self, nonce):
        """构建CSP头值"""
        policies = []
        
        # 添加nonce到script-src和style-src
        if 'script-src' in self.default_csp:
            self.default_csp['script-src'].append(f"'nonce-{nonce}'")
        
        if 'style-src' in self.default_csp:
            self.default_csp['style-src'].append(f"'nonce-{nonce}'")
        
        # 构建策略字符串
        for directive, sources in self.default_csp.items():
            policies.append(f"{directive} {' '.join(sources)}")
        
        # 添加报告URI
        if self.csp_report_uri:
            policies.append(f"report-uri {self.csp_report_uri}")
        
        return "; ".join(policies)

6.3 点击劫持保护中间件

class ClickjackingProtectionMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
        
        # 配置参数
        from django.conf import settings
        self.exempt_paths = getattr(settings, 'CLICKJACKING_EXEMPT_PATHS', [])
        self.frame_options = getattr(settings, 'X_FRAME_OPTIONS', 'DENY')  # 'DENY', 'SAMEORIGIN' 或 'ALLOW-FROM uri'
    
    def __call__(self, request):
        response = self.get_response(request)
        
        # 检查路径是否在豁免列表中
        if not any(request.path.startswith(path) for path in self.exempt_paths):
            # 添加X-Frame-Options头
            response['X-Frame-Options'] = self.frame_options
        
        return response

7. 缓存中间件实现

7.1 基本页面缓存中间件

创建页面缓存中间件:

from django.core.cache import cache
from django.utils.cache import get_cache_key, learn_cache_key
from django.utils.deprecation import MiddlewareMixin
from django.conf import settings

class PageCacheMiddleware(MiddlewareMixin):
    def process_request(self, request):
        """尝试从缓存获取响应"""
        # 只缓存GET请求
        if request.method != 'GET':
            return None
        
        # 不缓存已认证用户的请求
        if request.user.is_authenticated:
            return None
        
        # 生成缓存键
        cache_key = self.get_cache_key(request)
        if cache_key is None:
            return None
        
        # 尝试从缓存读取
        response = cache.get(cache_key)
        if response is None:
            return None
        
        return response
    
    def process_response(self, request, response):
        """缓存成功的GET响应"""
        # 只缓存GET请求和成功的响应
        if (
            request.method != 'GET' or 
            response.status_code != 200 or 
            request.user.is_authenticated or
            not self.should_cache_response(request, response)
        ):
            return response
        
        # 生成并设置缓存键
        cache_key = learn_cache_key(request, response)
        if hasattr(response, 'render') and callable(response.render):
            # 处理模板响应
            response.add_post_render_callback(
                lambda rendered_response: cache.set(
                    cache_key, rendered_response, self.get_cache_timeout(request)
                )
            )
        else:
            # 处理普通响应
            cache.set(cache_key, response, self.get_cache_timeout(request))
        
        return response
    
    def get_cache_key(self, request):
        """生成缓存键"""
        return get_cache_key(request)
    
    def get_cache_timeout(self, request):
        """获取缓存超时时间"""
        # 不同URL可以有不同的缓存超时时间
        path = request.path
        if path.startswith('/blog/'):
            return 60 * 60  # 博客页面缓存1小时
        elif path.startswith('/product/'):
            return 60 * 15  # 产品页面缓存15分钟
        else:
            return getattr(settings, 'PAGE_CACHE_SECONDS', 60 * 5)  # 默认5分钟
    
    def should_cache_response(self, request, response):
        """决定是否应该缓存响应"""
        # 不缓存带有特定Cookie的请求
        if request.COOKIES.get('no_cache'):
            return False
        
        # 不缓存某些路径
        for prefix in getattr(settings, 'CACHE_EXCLUDED_PATHS', []):
            if request.path.startswith(prefix):
                return False
        
        # 不缓存包含特定头的响应
        if response.get('Cache-Control') == 'no-cache':
            return False
        
        return True

7.2 分层缓存中间件

实现多级缓存策略:

from django.core.cache import cache, caches
from django.utils.cache import get_cache_key
from django.utils.deprecation import MiddlewareMixin
import hashlib

class TieredCacheMiddleware(MiddlewareMixin):
    def __init__(self, get_response=None):
        super().__init__(get_response)
        
        # 配置不同的缓存后端
        self.memory_cache = caches['memory']  # 内存缓存
        self.disk_cache = caches['disk']      # 磁盘缓存
        self.redis_cache = cache              # Redis缓存
        
        # 缓存时间配置(秒)
        self.memory_timeout = 60              # 内存缓存1分钟
        self.disk_timeout = 60 * 15           # 磁盘缓存15分钟
        self.redis_timeout = 60 * 60 * 4      # Redis缓存4小时
    
    def process_request(self, request):
        """尝试从不同层级的缓存获取响应"""
        if request.method != 'GET':
            return None
        
        cache_key = self.get_cache_key(request)
        if not cache_key:
            return None
        
        # 首先尝试内存缓存(最快)
        response = self.memory_cache.get(cache_key)
        if response:
            return response
        
        # 然后尝试磁盘缓存
        response = self.disk_cache.get(cache_key)
        if response:
            # 填充内存缓存
            self.memory_cache.set(cache_key, response, self.memory_timeout)
            return response
        
        # 最后尝试Redis缓存
        response = self.redis_cache.get(cache_key)
        if response:
            # 填充内存和磁盘缓存
            self.memory_cache.set(cache_key, response, self.memory_timeout)
            self.disk_cache.set(cache_key, response, self.disk_timeout)
            return response
        
        return None
    
    def process_response(self, request, response):
        """将响应保存到多个缓存层"""
        if (
            request.method != 'GET' or 
            response.status_code != 200 or 
            not self.should_cache_response(request, response)
        ):
            return response
        
        cache_key = self.get_cache_key(request)
        if not cache_key:
            return response
        
        # 为不同类型的响应设置不同的缓存策略
        content_type = response.get('Content-Type', '')
        is_html = 'text/html' in content_type
        is_api = 'application/json' in content_type
        
        # 根据响应类型设置缓存
        if is_api:
            # API响应:使用短期缓存
            self.memory_cache.set(cache_key, response, 30)  # 30秒
            self.redis_cache.set(cache_key, response, 60)   # 1分钟
        
        elif is_html:
            # HTML响应:使用中期缓存
            self.memory_cache.set(cache_key, response, self.memory_timeout)
            self.disk_cache.set(cache_key, response, self.disk_timeout)
            
            # 静态页面可以缓存更长时间
            if 'static-page' in request.GET:
                self.redis_cache.set(cache_key, response, self.redis_timeout)
        
        else:
            # 其他类型响应:常规缓存策略
            self.memory_cache.set(cache_key, response, self.memory_timeout)
            self.disk_cache.set(cache_key, response, self.disk_timeout)
            self.redis_cache.set(cache_key, response, self.redis_timeout)
        
        return response
    
    def get_cache_key(self, request):
        """生成缓存键"""
        key = get_cache_key(request)
        if not key:
            # 备选方法:生成自定义缓存键
            key_parts = [
                request.path,
                request.META.get('QUERY_STRING', ''),
                request.COOKIES.get('session_variant', ''),  # 用于A/B测试
            ]
            key = hashlib.md5(''.join(key_parts).encode()).hexdigest()
        
        return key
    
    def should_cache_response(self, request, response):
        """决定是否应该缓存响应"""
        # 实现缓存决策逻辑...
        return True

8. 调试与监控中间件

8.1 SQL查询监控中间件

创建SQL查询监控中间件:

import time
import json
import logging
from django.db import connection
from django.conf import settings

logger = logging.getLogger('sql_performance')

class SQLMonitoringMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
        self.slow_query_threshold = getattr(settings, 'SLOW_QUERY_THRESHOLD', 0.1)  # 秒
    
    def __call__(self, request):
        # 只在DEBUG模式或明确启用时启用SQL监控
        if not (settings.DEBUG or getattr(settings, 'SQL_MONITORING_ENABLED', False)):
            return self.get_response(request)
        
        # 记录起始查询数
        start_queries = len(connection.queries)
        start_time = time.time()
        
        # 处理请求
        response = self.get_response(request)
        
        # 计算查询统计信息
        end_time = time.time()
        duration = end_time - start_time
        
        # 获取请求期间执行的查询
        queries = connection.queries[start_queries:]
        num_queries = len(queries)
        
        if num_queries > 0:
            # 计算查询总时间
            query_time = sum(float(q['time']) for q in queries)
            
            # 找出慢查询
            slow_queries = [q for q in queries if float(q['time']) > self.slow_query_threshold]
            
            # 记录统计信息
            logger.info(
                f"Request: {request.method} {request.path} | "
                f"Duration: {duration:.4f}s | "
                f"Queries: {num_queries} | "
                f"Query Time: {query_time:.4f}s | "
                f"Slow Queries: {len(slow_queries)}"
            )
            
            # 记录慢查询详情
            for i, query in enumerate(slow_queries):
                logger.warning(
                    f"Slow Query #{i+1} ({query['time']}s): {query['sql']}"
                )
            
            # 在响应中添加SQL统计信息(仅在DEBUG模式)
            if settings.DEBUG and request.headers.get('X-Requested-With') == 'XMLHttpRequest':
                # 对于AJAX请求,返回JSON响应中添加调试信息
                if 'application/json' in response.get('Content-Type', ''):
                    try:
                        data = json.loads(response.content.decode('utf-8'))
                        
                        # 添加调试信息
                        if isinstance(data, dict):
                            data['__debug'] = {
                                'duration': duration,
                                'queries': num_queries,
                                'query_time': query_time,
                                'slow_queries': len(slow_queries)
                            }
                            
                            # 更新响应内容
                            response.content = json.dumps(data).encode('utf-8')
                            response['Content-Length'] = str(len(response.content))
                    except (json.JSONDecodeError, UnicodeDecodeError

## 4. 中间件应用场景与实战解析

### 4.1 性能监控中间件

监控请求执行时间是优化应用性能的第一步。这个中间件可以帮助你识别慢请求:

```python
import time
import logging
from django.conf import settings

logger = logging.getLogger('performance')

class PerformanceMonitorMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
        # 设置阈值,超过该时间(ms)的请求将被记录
        self.threshold = getattr(settings, 'SLOW_REQUEST_THRESHOLD', 500)
    
    def __call__(self, request):
        start_time = time.time()
        response = self.get_response(request)
        duration = (time.time() - start_time) * 1000  # 转换为毫秒
        
        if duration > self.threshold:
            logger.warning(
                f'慢请求: {request.method} {request.path} 耗时 {duration:.2f}ms'
            )
            
        return response

在生产环境中,你可以配合ELK或Prometheus等监控系统,实现性能数据可视化和告警。

4.2 IP限流中间件

防止恶意请求和DOS攻击的简单限流实现:

from django.core.cache import cache
from django.http import HttpResponse
import time

class RateLimitMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
        self.rate_limit = 60  # 每分钟最大请求数
        self.window = 60  # 时间窗口(秒)
    
    def __call__(self, request):
        # 获取客户端IP
        ip = self.get_client_ip(request)
        cache_key = f'rate_limit:{ip}'
        
        # 获取当前请求计数
        requests = cache.get(cache_key, [])
        now = time.time()
        
        # 仅保留时间窗口内的请求记录
        requests = [t for t in requests if now - t < self.window]
        
        # 检查是否超过限制
        if len(requests) >= self.rate_limit:
            return HttpResponse('请求频率过高,请稍后再试', status=429)
        
        # 添加当前请求时间
        requests.append(now)
        cache.set(cache_key, requests, self.window)
        
        # 继续处理请求
        return self.get_response(request)
    
    def get_client_ip(self, request):
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if x_forwarded_for:
            ip = x_forwarded_for.split(',')[0]
        else:
            ip = request.META.get('REMOTE_ADDR')
        return ip

4.3 跨域资源共享(CORS)中间件

实现自定义CORS策略(而不是使用django-cors-headers):

class CorsMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
        
    def __call__(self, request):
        response = self.get_response(request)
        
        # 添加CORS头信息
        response["Access-Control-Allow-Origin"] = "https://trusted-site.com"
        response["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS"
        response["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
        response["Access-Control-Allow-Credentials"] = "true"
        
        return response
    
    def process_view(self, request, view_func, view_args, view_kwargs):
        # 处理预检请求
        if request.method == "OPTIONS":
            response = HttpResponse()
            response["Access-Control-Allow-Origin"] = "https://trusted-site.com"
            response["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS"
            response["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
            response["Access-Control-Allow-Credentials"] = "true"
            response["Access-Control-Max-Age"] = "86400"  # 24小时
            return response
        return None

5. 中间件性能与优化

中间件在每个请求中都会执行,因此它们的性能直接影响整个应用的响应时间。

5.1 性能优化策略

  1. 减少处理逻辑:中间件应专注于其核心功能,避免复杂计算
  2. 使用缓存:频繁访问的数据应当缓存
  3. 延迟导入:在__init__方法中导入依赖,而不是模块级别
  4. 条件执行:通过URL路径或请求类型过滤,只在必要时处理
class ConditionalMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
        # 延迟导入依赖
        import re
        self.re = re
        # 定义白名单
        self.url_patterns = [
            self.re.compile(r'^/api/'),
            self.re.compile(r'^/admin/'),
        ]
        
    def __call__(self, request):
        # 条件执行
        if self.should_process(request):
            # 处理逻辑
            pass
        
        return self.get_response(request)
    
    def should_process(self, request):
        # 检查是否匹配白名单路径
        path = request.path
        return any(pattern.match(path) for pattern in self.url_patterns)

6. 中间件测试策略

中间件测试是确保它们正常工作的关键。以下是测试中间件的最佳实践:

6.1 单元测试

from django.test import TestCase, RequestFactory
from django.http import HttpResponse
from myapp.middleware import PerformanceMonitorMiddleware

class PerformanceMiddlewareTest(TestCase):
    def setUp(self):
        self.factory = RequestFactory()
        
        # 创建一个模拟的get_response函数
        def get_response(request):
            response = HttpResponse("测试响应")
            return response
            
        self.middleware = PerformanceMonitorMiddleware(get_response)
    
    def test_slow_request_is_logged(self):
        # 创建请求
        request = self.factory.get('/test/')
        
        # 模拟慢响应
        import time
        original_time = time.time
        
        try:
            # 模拟时间流逝
            mock_time_values = [100.0, 101.0]  # 模拟1秒(1000ms)的请求
            time.time = lambda: mock_time_values.pop(0)
            
            with self.assertLogs('performance', level='WARNING') as cm:
                response = self.middleware(request)
                self.assertIn('慢请求', cm.output[0])
                
        finally:
            # 恢复原始函数
            time.time = original_time

6.2 集成测试

from django.test import TestCase, Client
from django.urls import reverse

class RateLimitMiddlewareIntegrationTest(TestCase):
    def setUp(self):
        self.client = Client()
        
    def test_rate_limit_exceeded(self):
        url = reverse('api-endpoint')
        
        # 发送超过限制的请求
        for _ in range(61):  # 假设限制为60
            response = self.client.get(url)
            
        # 验证最后一个请求被限制
        self.assertEqual(response.status_code, 429)
        self.assertIn('请求频率过高', response.content.decode())

7. 与内置中间件协同工作

Django有许多内置中间件,如何与它们协同工作是中间件开发的重要方面。

7.1 中间件顺序

settings.py中,中间件按照定义顺序执行:

MIDDLEWARE = [
    'django.middleware.security.SecurityMiddleware',
    'django.contrib.sessions.middleware.SessionMiddleware',
    'myproject.middleware.custom_middleware.CustomMiddleware',  # 自定义中间件
    'django.middleware.common.CommonMiddleware',
    'django.middleware.csrf.CsrfViewMiddleware',
    'django.contrib.auth.middleware.AuthenticationMiddleware',
    'django.contrib.messages.middleware.MessageMiddleware',
    'django.middleware.clickjacking.XFrameOptionsMiddleware',
]

7.2 与内置中间件交互

class CustomAuthMiddleware:
    def __init__(self, get_response):
        self.get_response = get_response
        
    def __call__(self, request):
        # 需要在SessionMiddleware和AuthenticationMiddleware之后
        # 才能访问request.user和request.session
        if not request.user.is_authenticated:
            # 自定义认证逻辑
            pass
            
        return self.get_response(request)

8. 总结与最佳实践

开发Django中间件时,牢记以下最佳实践:

  1. 职责单一:每个中间件只负责一个明确的功能
  2. 性能优先:中间件执行频率高,必须高效
  3. 错误处理:妥善处理异常,不影响请求流程
  4. 测试覆盖:编写全面的测试确保正常工作
  5. 文档完善:为团队成员提供清晰的使用指南
  6. 配置灵活:通过settings提供可配置选项
  7. 兼容性考虑:特别是与内置中间件的交互
  8. 版本迁移:确保在Django版本升级时平滑过渡

Django中间件是构建健壮、可维护Web应用的强大工具。通过本文介绍的模式和实践,你可以充分发挥中间件的潜力,构建出更加安全、高效、易于维护的Django应用。

在下一篇文章中,我们将深入探讨Django的模板系统,带你掌握如何构建灵活且高效的前端展示层。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Is code

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值