Django REST Framework: Throttling, Versioning, and Parsers, a Source Code Walkthrough (Very Detailed)


Throttling

Throttling limits how often an endpoint may be called, which reduces the load on the server.

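The idea behind a custom throttle:
Mechanism: keep a dictionary whose key is a unique identifier (an IP address, a username, a user id, and so on) and whose value is a list of access timestamps, newest first. For example, to allow at most three requests within ten seconds:

{"ip": [16:20:40, 16:20:30, 16:20:20]}
Steps:
	If the client has no record yet, create one and store the current time.
	If the client has visited before, load its history and check whether the last entry (the oldest timestamp) is earlier than the current time minus 10 seconds.
	If it is, that entry has left the 10-second window, so remove it; repeat until the last entry is inside the window.
	To allow three requests per 10 seconds, check whether the list now holds fewer than three entries.
	If it does, the request is allowed and the current time is inserted at index 0; otherwise the request is rejected.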

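Custom throttle implementation:
Build on DRF's base throttle class, BaseThrottle, and override allow_request. Keep this in mind: the base implementation raises an error if it is not overridden. There is also a wait method, which returns the number of seconds the client should wait, computed dynamically.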

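import time

from rest_framework.throttling import BaseThrottle

VISIT_RECORD = {}  # in practice this usually lives in a cache


class MyThrottle(BaseThrottle):
    def __init__(self):
        self.history = None

    def allow_request(self, request, view):
        remote_addr = request.META.get('REMOTE_ADDR')
        current_time = time.time()
        if remote_addr not in VISIT_RECORD:
            # First visit: store a list so later calls can pop and insert.
            VISIT_RECORD[remote_addr] = [current_time]
            return True

        history = VISIT_RECORD.get(remote_addr)
        self.history = history
        # Drop timestamps older than the 10-second window (the oldest is last).
        while history and history[-1] < current_time - 10:
            history.pop()

        if len(history) < 3:
            history.insert(0, current_time)
            return True
        return False  # limit reached: throttle this request

    def wait(self):
        # Seconds until the oldest recorded request leaves the window.
        current_time = time.time()
        return 10 - (current_time - self.history[-1])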
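
The same sliding-window logic can be exercised without Django at all. The following is a minimal sketch, not part of DRF: the class name SlidingWindowThrottle, its parameters, and the injectable timer are assumptions made for illustration, and the clock is simulated so the outcome is deterministic.

```python
import time


class SlidingWindowThrottle:
    """Framework-free sketch of the sliding-window idea (hypothetical class)."""

    def __init__(self, max_requests=3, window=10, timer=time.time):
        self.max_requests = max_requests
        self.window = window
        self.timer = timer      # injectable clock, so the demo is deterministic
        self.records = {}       # ident -> list of timestamps, newest first

    def allow(self, ident):
        now = self.timer()
        history = self.records.setdefault(ident, [])
        # Drop timestamps that have left the window (the oldest sit at the end).
        while history and history[-1] <= now - self.window:
            history.pop()
        if len(history) >= self.max_requests:
            return False
        history.insert(0, now)
        return True


# Simulated clock: a one-element list holding the "current time" in seconds.
clock = [0.0]
throttle = SlidingWindowThrottle(timer=lambda: clock[0])

results = []
for t in (0, 1, 2, 3, 11):
    clock[0] = float(t)
    results.append(throttle.allow("127.0.0.1"))

print(results)  # [True, True, True, False, True]
```

The fourth request at t=3 is rejected because three timestamps are still inside the window; by t=11 the entries from t=0 and t=1 have expired, so the request is allowed again.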

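The above is hand-written throttling logic. DRF already implements throttling internally; you only need to import it:

SimpleRateThrottle source walkthrough: simple rate limiting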

class SimpleRateThrottle(BaseThrottle):
    """
    A simple cache implementation, that only requires `.get_cache_key()`
    to be overridden.

    The rate (requests / seconds) is set by a `rate` attribute on the View
    class.  The attribute is a string of the form 'number_of_requests/period'.

    Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')

    Previous request information used for throttling is stored in the cache.
    """
    cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        # __init__ runs first and resolves the rate.
        if not getattr(self, 'rate', None):
            # get_rate() returns the value configured for this scope, e.g. "3/m"
            self.rate = self.get_rate()

        self.num_requests, self.duration = self.parse_rate(self.rate)

    # Subclasses must override get_cache_key
    def get_cache_key(self, request, view):
        """
        Should return a unique cache-key which can be used for throttling.
        Must be overridden.

        May return `None` if the request should not be throttled.
        """
        raise NotImplementedError('.get_cache_key() must be overridden')

    def get_rate(self):
        """
        Determine the string representation of the allowed request rate.
        """
        # scope (or rate) must be set, otherwise an exception is raised here
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            # DEFAULT_THROTTLE_RATES in the settings must contain a key for this scope
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def parse_rate(self, rate):
        """
        Given the request rate string, return a two tuple of:
        <allowed number of requests>, <period of time in seconds>
        """
        if rate is None:
            return (None, None)
        # "3/m"
        num, period = rate.split('/')
        num_requests = int(num)
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        # Returns the request count and the window in seconds, e.g. (3, 60)
        return (num_requests, duration)

    def allow_request(self, request, view):
        """
        Implement the check to see if the request should be throttled.

        On success calls `throttle_success`.
        On failure calls `throttle_failure`.
        """
        if self.rate is None:
            return True

        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True
		
        # Django's built-in cache
        self.history = self.cache.get(self.key, [])
        # time.time
        self.now = self.timer()

        # Drop any requests from the history which have now passed the
        # throttle duration
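        # Same logic as the hand-written version above:
        # check whether the oldest timestamp is no later than now minus the duration;
        # if so it has left the window, so drop the last entry.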
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        
        # If the history already holds num_requests entries, reject the request (return False)
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()

    def throttle_success(self):
        """
        Inserts the current request's timestamp along with the key
        into the cache.
        """
        # Insert the current time at index 0 and store the list in the cache
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
        """
        Called when a request to the API has failed due to throttling.
        """
        return False

    def wait(self):
        """
        Returns the recommended next request time in seconds.
        """
        # Computed dynamically: duration - (now - oldest recorded request time)
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration
		
        # Number of requests that can still be made
        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        return remaining_duration / float(available_requests)
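
To make the arithmetic in wait() concrete, here is a standalone copy of that calculation as a plain function. The function name wait_seconds and the sample timestamps are assumptions for illustration; the formula itself is the one in the source above.

```python
def wait_seconds(history, now, num_requests, duration):
    """Plain-function copy of the arithmetic in SimpleRateThrottle.wait()."""
    if history:
        remaining_duration = duration - (now - history[-1])
    else:
        remaining_duration = duration
    available_requests = num_requests - len(history) + 1
    if available_requests <= 0:
        return None
    return remaining_duration / float(available_requests)


# Rate "3/m": requests were accepted at t=0, 5 and 10 (stored newest first).
# A fourth request arrives at t=20 and is rejected.
history = [10.0, 5.0, 0.0]
suggested = wait_seconds(history, now=20.0, num_requests=3, duration=60)
print(suggested)  # 40.0
```

The oldest entry (t=0) leaves the 60-second window 40 seconds after t=20, and exactly one slot opens at that point, so the suggested wait is 40 seconds.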

Custom throttle class:

from rest_framework.settings import api_settings
from rest_framework.throttling import SimpleRateThrottle


class MySimpleRateThrottle(SimpleRateThrottle):
    scope = "xxoo"  # must have a matching key in the settings

    # Must be overridden
    def get_cache_key(self, request, view):
        # Unique identifier for the client
        return self.get_ident(request)

    # get_ident is already implemented in the base class BaseThrottle and
    # identifies the client by IP. It is reproduced here for reference only;
    # there is no need to redefine it.
    def get_ident(self, request):
        """
        Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
        if present and number of proxies is > 0. If not use all of
        HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
        """
        xff = request.META.get('HTTP_X_FORWARDED_FOR')
        remote_addr = request.META.get('REMOTE_ADDR')
        num_proxies = api_settings.NUM_PROXIES

        if num_proxies is not None:
            if num_proxies == 0 or xff is None:
                return remote_addr
            addrs = xff.split(',')
            client_addr = addrs[-min(num_proxies, len(addrs))]
            return client_addr.strip()

        return ''.join(xff.split()) if xff else remote_addr
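
How NUM_PROXIES changes the identifier is easier to see with concrete values. This sketch lifts the body of get_ident into a plain function that takes the two header values directly; the function signature and the sample addresses (documentation-range IPs) are assumptions for illustration.

```python
def get_ident(xff, remote_addr, num_proxies=None):
    """Same branching as get_ident above, with the request replaced by plain arguments."""
    if num_proxies is not None:
        if num_proxies == 0 or xff is None:
            return remote_addr
        addrs = xff.split(',')
        client_addr = addrs[-min(num_proxies, len(addrs))]
        return client_addr.strip()
    return ''.join(xff.split()) if xff else remote_addr


xff = "203.0.113.7, 10.0.0.1, 10.0.0.2"
remote = "10.0.0.3"

print(get_ident(xff, remote))                  # 203.0.113.7,10.0.0.1,10.0.0.2
print(get_ident(xff, remote, num_proxies=0))   # 10.0.0.3
print(get_ident(xff, remote, num_proxies=1))   # 10.0.0.2
print(get_ident(xff, remote, num_proxies=2))   # 10.0.0.1
print(get_ident(None, remote, num_proxies=2))  # 10.0.0.3
```

With NUM_PROXIES left at None, the whole X-Forwarded-For header (whitespace removed) becomes the identifier, so a client that can set that header can also change its throttle bucket. Setting NUM_PROXIES to the real number of proxies in front of the application avoids that.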
	

Settings:

REST_FRAMEWORK = {
    "DEFAULT_THROTTLE_RATES": {
        "xxoo": "3/m"
    },
}

The APIView needs throttle_classes set:

class UserInfo(APIView):
    throttle_classes = [MySimpleRateThrottle, ]

Result: (screenshot omitted)

Built-in throttle classes:

Anonymous throttle: rate limiting for unauthenticated (anonymous) users

class AnonRateThrottle(SimpleRateThrottle):
    """
    Limits the rate of API calls that may be made by a anonymous users.

    The IP address of the request will be used as the unique cache key.
    """
    scope = 'anon'

    def get_cache_key(self, request, view):
        if request.user.is_authenticated:
            return None  # Only throttle unauthenticated requests.

        return self.cache_format % {
            'scope': self.scope,
            'ident': self.get_ident(request)
        }

Authenticated-user throttle: rate limiting for logged-in (authenticated) users

class UserRateThrottle(SimpleRateThrottle):
    """
    Limits the rate of API calls that may be made by a given user.

    The user id will be used as a unique cache key if the user is
    authenticated.  For anonymous requests, the IP address of the request will
    be used.
    """
    scope = 'user'

    def get_cache_key(self, request, view):
        if request.user.is_authenticated:
            ident = request.user.pk
        else:
            ident = self.get_ident(request)

        return self.cache_format % {
            'scope': self.scope,
            'ident': ident
        }

Scoped throttle: applies to both authenticated and anonymous users, with the rate chosen per view scope

class ScopedRateThrottle(SimpleRateThrottle):
    """
    Limits the rate of API calls by different amounts for various parts of
    the API.  Any view that has the `throttle_scope` property set will be
    throttled.  The unique cache key will be generated by concatenating the
    user id of the request, and the scope of the view being accessed.
    """
    scope_attr = 'throttle_scope'

    def __init__(self):
        # Override the usual SimpleRateThrottle, because we can't determine
        # the rate until called by the view.
        pass

    def allow_request(self, request, view):
        # We can only determine the scope once we're called by the view.
        self.scope = getattr(view, self.scope_attr, None)

        # If a view does not have a `throttle_scope` always allow the request
        if not self.scope:
            return True

        # Determine the allowed request rate as we normally would during
        # the `__init__` call.
        self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

        # We can now proceed as normal.
        return super().allow_request(request, view)

    def get_cache_key(self, request, view):
        """
        If `view.throttle_scope` is not set, don't apply this throttle.

        Otherwise generate the unique cache key by concatenating the user id
        with the '.throttle_scope` property of the view.
        """
        if request.user.is_authenticated:
            ident = request.user.pk
        else:
            ident = self.get_ident(request)

        return self.cache_format % {
            'scope': self.scope,
            'ident': ident
        }

Settings:

REST_FRAMEWORK = {
    'DEFAULT_THROTTLE_CLASSES': (
        # Throttle for unauthenticated (anonymous) users
        'rest_framework.throttling.AnonRateThrottle',
        # Throttle for logged-in (authenticated) users
        'rest_framework.throttling.UserRateThrottle'
    ),
    # Rates per scope
    'DEFAULT_THROTTLE_RATES': {
        # Rate for authenticated users
        'user': '5/minute',
        # Rate for anonymous users
        'anon': '3/minute',
    },
}
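
Both the long form ('5/minute') and the short form ('3/m') are accepted because parse_rate, shown earlier, only looks at the first character of the period. The following standalone copy of that logic illustrates this; it is a sketch for experimentation, not a replacement for the DRF method.

```python
def parse_rate(rate):
    """Standalone copy of the parse_rate logic shown in the source above."""
    if rate is None:
        return (None, None)
    num, period = rate.split('/')
    num_requests = int(num)
    # Only the first character of the period is inspected.
    duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
    return (num_requests, duration)


print(parse_rate('5/minute'))  # (5, 60)
print(parse_rate('3/m'))       # (3, 60)
print(parse_rate('100/day'))   # (100, 86400)
print(parse_rate('10/month'))  # (10, 60)
print(parse_rate(None))        # (None, None)
```

One consequence of the first-character lookup: with this logic a period such as 'month' is read as a minute, so it is safer to stick to the periods listed in the class docstring.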

The settings for scoped throttling are slightly different:

REST_FRAMEWORK = {
    # Throttle by scope
    'DEFAULT_THROTTLE_CLASSES': (
        'rest_framework.throttling.ScopedRateThrottle',
    ),

    # Rate for each scope
    'DEFAULT_THROTTLE_RATES': {
        'list': '3/m',
        'get': '5/m'
    },
}

In the views:

class ListView(APIView):
    throttle_scope = 'list'
    ...

class DetailView(APIView):
    throttle_scope = 'list'
    ...

class GetView(APIView):
    throttle_scope = 'get'
    ...
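
Note that ListView and DetailView above use the same throttle_scope. Since the cache key is built from the scope and the client identifier, the two views draw from one shared budget. A small sketch using the cache_format string from SimpleRateThrottle; the helper cache_key and the VIEW_SCOPES mapping are hypothetical names used only for this illustration.

```python
CACHE_FORMAT = 'throttle_%(scope)s_%(ident)s'  # same format string as SimpleRateThrottle


def cache_key(scope, ident):
    """Build a throttle cache key the way get_cache_key does (hypothetical helper)."""
    return CACHE_FORMAT % {'scope': scope, 'ident': ident}


# throttle_scope of each view in the example above
VIEW_SCOPES = {'ListView': 'list', 'DetailView': 'list', 'GetView': 'get'}

keys = {view: cache_key(scope, '127.0.0.1') for view, scope in VIEW_SCOPES.items()}
for view, key in keys.items():
    print(view, key)
# ListView throttle_list_127.0.0.1
# DetailView throttle_list_127.0.0.1
# GetView throttle_get_127.0.0.1
```

Requests from one client to ListView and DetailView therefore count against the same '3/m' limit, while GetView has its own '5/m' counter. Give DetailView its own scope if it should be limited independently.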

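Versioning

Every system has versions: they mark and record its iterations. Backends commonly put the version in the URL, for example /api/v1/users. The convention comes from RESTful API design, and as I understand it, it is tied to how the system iterates.
DRF has a versioning concept too, presumably for compatibility with RESTful conventions. Personally I do not find it necessary, since most projects define their own scheme, but it is still worth seeing how DRF does it. Two schemes are covered in detail here: reading the version from a query parameter, and embedding it in the URL.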

QueryParameterVersioning: version passed as a query parameter
class QueryParameterVersioning(BaseVersioning):
    """
    GET /something/?version=0.1 HTTP/1.1  (the version is read from the query string; the parameter name is configurable in settings)
    Host: example.com
    Accept: application/json
    """
    invalid_version_message = _('Invalid version in query parameter.')

    def determine_version(self, request, *args, **kwargs):
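        # The parameter name defaults to "version"; if absent, the default version is used (both configurable)
        version = request.query_params.get(self.version_param, self.default_version)
        # Check against the allowed versions
        if not self.is_allowed_version(version):
            # Not an allowed version: raise immediately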
            raise exceptions.NotFound(self.invalid_version_message)
        return version

Settings:

REST_FRAMEWORK = {
    "DEFAULT_VERSION": "v1",  # default version
    "ALLOWED_VERSIONS": ["v1", "v2"],  # allowed versions
    "VERSION_PARAM": "version",  # name of the version parameter
}
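
The interplay of these three settings can be sketched without DRF. In the example below the function, its defaults, and the InvalidVersion exception are stand-ins invented for illustration; in particular the allow-list check is a simplified assumption, since the source of is_allowed_version is not shown in this article.

```python
class InvalidVersion(Exception):
    """Stand-in for rest_framework.exceptions.NotFound in this sketch."""


def determine_version(query_params, default_version="v1",
                      allowed_versions=("v1", "v2"), version_param="version"):
    # Fall back to the default version when the parameter is missing.
    version = query_params.get(version_param, default_version)
    # Simplified allow-list check (assumption, see lead-in).
    if allowed_versions and version not in allowed_versions:
        raise InvalidVersion("Invalid version in query parameter.")
    return version


print(determine_version({}))                 # v1
print(determine_version({"version": "v2"}))  # v2

try:
    determine_version({"version": "v3"})
except InvalidVersion as exc:
    print("rejected:", exc)                  # rejected: Invalid version in query parameter.
```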

In the view:

class UserInfo(APIView):
    # Note: this is a single class, not a list
    versioning_class = QueryParameterVersioning
    
    def get(self, request):
        version = request._request.GET.get("version")
        print(version)
        # The version is also available as request.version: DRF resolves it and stores it on the request
        print(request.version)
        return Response()

Result: (screenshot omitted)

URLPathVersioning: version in the URL path (recommended)
class URLPathVersioning(BaseVersioning):
    """
    To the client this is the same style as `NamespaceVersioning`.
    The difference is in the backend - this implementation uses
    Django's URL keyword arguments to determine the version.

    An example URL conf for two views that accept two different versions.

    urlpatterns = [
        url(r'^(?P<version>[v1|v2]+)/users/$', users_list, name='users-list'),
        url(r'^(?P<version>[v1|v2]+)/users/(?P<pk>[0-9]+)/$', users_detail, name='users-detail')
    ]

    GET /1.0/something/ HTTP/1.1
    Host: example.com
    Accept: application/json
    """

In the project's urls.py:

from django.contrib import admin
from django.conf.urls import url, include

urlpatterns = [
    url('admin/', admin.site.urls),
    url('api/', include("djcelerytest.urls"), name="xxoo"),
]

In the app's urls.py:

from django.conf.urls import url
from djcelerytest.views import test, UserInfo

urlpatterns = [
    url('test/', test),
    url(r'^(?P<version>[v1|v2]+)/info/', UserInfo.as_view()),
]
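
One detail about the pattern `[v1|v2]+`: square brackets form a character class, so it matches any run of the characters v, 1, | and 2, not just the strings v1 and v2. The comparison below uses only the re module; the stricter pattern is a suggested alternative, not something taken from the DRF source.

```python
import re

# Pattern as used in the urls.py above: a character class, repeated.
loose = re.compile(r'^(?P<version>[v1|v2]+)/info/')
# Suggested stricter alternative (assumption): an alternation of whole strings.
strict = re.compile(r'^(?P<version>v1|v2)/info/')

paths = ('v1/info/', 'v2/info/', 'vv/info/', '1|2/info/', 'v3/info/')
matches = {p: (bool(loose.match(p)), bool(strict.match(p))) for p in paths}

for path, (is_loose, is_strict) in matches.items():
    print(path, is_loose, is_strict)
# v1/info/ True True
# v2/info/ True True
# vv/info/ True False
# 1|2/info/ True False
# v3/info/ False False
```

If ALLOWED_VERSIONS is configured, values such as "vv" are still rejected later by the versioning class, so the loose pattern is mostly a readability issue rather than a security hole.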

In the view:

class UserInfo(APIView):
    versioning_class = URLPathVersioning

    def get(self, request, *args, **kwargs):
        # DRF resolves the version internally and assigns it to request.version
        print(request.version)
        return Response()

Versioning source walkthrough

In initial, the version is determined before authentication and permission checks run.
APIView.determine_version:

def determine_version(self, request, *args, **kwargs):
    """
    If versioning is being used, then determine any API version for the
    incoming request. Returns a two-tuple of (version, versioning_scheme)
    """
    if self.versioning_class is None:
        return (None, None)
       
    # Either the view's own versioning_class = URLPathVersioning or the one from settings
    scheme = self.versioning_class()
    # Call determine_version on the URLPathVersioning instance
    return (scheme.determine_version(request, *args, **kwargs), scheme)

URLPathVersioning.determine_version:

def determine_version(self, request, *args, **kwargs):
    # version_param: "version"
    # default_version: the value from settings, "v1"
    version = kwargs.get(self.version_param, self.default_version)
    if version is None:
        version = self.default_version
	
    # Check against the allowed versions
    if not self.is_allowed_version(version):
        raise exceptions.NotFound(self.invalid_version_message)
    
    # Return the version
    return version

Assigning to the request:

# Determine the API version, if versioning is in use.
# Returns the version and the scheme; the scheme is used for reverse URL resolution
version, scheme = self.determine_version(request, *args, **kwargs)
# Assign both to the request
request.version, request.versioning_scheme = version, scheme

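scheme: used for reverse URL resolution. The URL below requires a version keyword argument, which URLPathVersioning's reverse() fills in from request.version; QueryParameterVersioning would append ?version=... instead and could not resolve this pattern.
urls.py:

urlpatterns = [
    url(r'^(?P<version>[v1|v2]+)/info/$', UserInfo.as_view(), name="xxoo"),
]

class UserInfo(APIView):
    versioning_class = URLPathVersioning

    def get(self, request, *args, **kwargs):
        print(request.version)
        print(request.versioning_scheme.reverse(viewname="xxoo", request=request))
        return Response()

# Output:
# v1
# http://127.0.0.1:8000/api/v1/info/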
Other versioning classes:
# Based on URL namespaces
class NamespaceVersioning(BaseVersioning):
    """
    To the client this is the same style as `URLPathVersioning`.
    The difference is in the backend - this implementation uses
    Django's URL namespaces to determine the version.

    An example URL conf that is namespaced into two separate versions

    # users/urls.py
    urlpatterns = [
        url(r'^/users/$', users_list, name='users-list'),
        url(r'^/users/(?P<pk>[0-9]+)/$', users_detail, name='users-detail')
    ]
    """

# Based on the subdomain
class HostNameVersioning(BaseVersioning):
    """
    GET /something/ HTTP/1.1
    Host: v1.example.com
    Accept: application/json
    """

# Based on a request header
class AcceptHeaderVersioning(BaseVersioning):
    """
    GET /something/ HTTP/1.1
    Host: example.com
    Accept: application/json; version=1.0
    """

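Parsers

The data we read from request.data or request.query_params has already been parsed for us by DRF, which picks a parser according to the request's Content-Type header. For comparison, here is how Django itself fills request.POST (from HttpRequest._load_post_and_files):

# Content type used for file uploads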
if self.content_type == 'multipart/form-data':
    if hasattr(self, '_body'):
        # Use already read data
        data = BytesIO(self._body)
    else:
        data = self
    try:
        self._post, self._files = self.parse_file_upload(self.META, data)
    except MultiPartParserError:
        # An error occurred while parsing POST data. Since when
        # formatting the error the request handler might access
        # self.POST, set self._post and self._file to prevent
        # attempts to parse POST data again.
        self._mark_post_parse_error()
        raise
# Only with Content-Type application/x-www-form-urlencoded does POST contain data; it is read from the body
# Body format: name=xx&age=18
elif self.content_type == 'application/x-www-form-urlencoded':
    self._post, self._files = QueryDict(self.body, encoding=self._encoding), MultiValueDict()
else:
    self._post, self._files = QueryDict(encoding=self._encoding), MultiValueDict()
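
The branching above explains why request.POST is empty when a client sends JSON. The sketch below mirrors only the urlencoded and fallback branches with standard-library tools; load_post is a hypothetical helper, the multipart branch is left out, and parse_qs stands in for Django's QueryDict.

```python
import json
from urllib.parse import parse_qs


def load_post(content_type, body):
    """Rough stand-in for the urlencoded and else branches shown above."""
    if content_type == "application/x-www-form-urlencoded":
        return parse_qs(body.decode("utf-8"))
    return {}  # any other content type: POST stays empty


form_post = load_post("application/x-www-form-urlencoded", b"name=xx&age=18")
json_post = load_post("application/json", b'{"name": "xx", "age": 18}')

print(form_post)  # {'name': ['xx'], 'age': ['18']}
print(json_post)  # {}

# For JSON the raw body has to be decoded explicitly,
# which is the job a JSON parser takes over.
print(json.loads(b'{"name": "xx", "age": 18}'.decode("utf-8")))  # {'name': 'xx', 'age': 18}
```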

Parsers in DRF:

JSONParser: only supports Content-Type = application/json

class JSONParser(BaseParser):
    """
    Parses JSON-serialized data.
    """
    media_type = 'application/json'   # matched against the request's Content-Type
    renderer_class = renderers.JSONRenderer
    strict = api_settings.STRICT_JSON

Any other content type produces this response:

{
    "detail": "Unsupported media type \"application/x-www-form-urlencoded\" in request."
}

FormParser: only supports Content-Type = 'application/x-www-form-urlencoded'

class FormParser(BaseParser):
    """
    Parser for form data.
    """
    media_type = 'application/x-www-form-urlencoded'

MultiPartParser: only supports Content-Type = 'multipart/form-data' (file uploads)

class MultiPartParser(BaseParser):
    """
    Parser for multipart form data, which may include file data.
    """
    media_type = 'multipart/form-data'

FileUploadParser: media_type = '*/*', so it accepts any content type

class FileUploadParser(BaseParser):
    """
    Parser for file upload data.
    """
    media_type = '*/*'
    errors = {
        'unhandled': 'FileUpload parse error - none of upload handlers can handle the stream',
        'no_filename': 'Missing filename. Request should include a Content-Disposition header with a filename parameter.',
    }

Specify the parsers on the view:

class UserInfo(APIView):
    parser_classes = [JSONParser]

They can also be configured globally:

REST_FRAMEWORK = {
    "DEFAULT_PARSER_CLASSES": ["rest_framework.parsers.JSONParser"],
}
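
Conceptually, the configured parsers are tried against the request's Content-Type and the first match handles the body. The sketch below imitates that selection in plain Python; the Sketch* classes, select_parser, and the simplified matching rule are all assumptions for illustration and do not reproduce DRF's content negotiation code.

```python
import json
from urllib.parse import parse_qs


class SketchJSONParser:
    media_type = "application/json"

    def parse(self, body):
        return json.loads(body.decode("utf-8"))


class SketchFormParser:
    media_type = "application/x-www-form-urlencoded"

    def parse(self, body):
        # Keep only the first value per key to keep the demo small.
        return {k: v[0] for k, v in parse_qs(body.decode("utf-8")).items()}


class UnsupportedMediaType(Exception):
    pass


def select_parser(content_type, parsers):
    """Return the first parser whose media_type matches (simplified rule)."""
    base = content_type.split(";")[0].strip().lower()
    for parser in parsers:
        if parser.media_type in (base, "*/*"):
            return parser
    raise UnsupportedMediaType('Unsupported media type "%s" in request.' % base)


parser = select_parser("application/json; charset=utf-8",
                       [SketchJSONParser(), SketchFormParser()])
data = parser.parse(b'{"name": "xx", "age": 18}')
print(type(parser).__name__, data)  # SketchJSONParser {'name': 'xx', 'age': 18}

try:
    # Only a JSON parser is configured, as in the settings above.
    select_parser("application/x-www-form-urlencoded", [SketchJSONParser()])
except UnsupportedMediaType as exc:
    print(exc)  # Unsupported media type "application/x-www-form-urlencoded" in request.
```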

When dispatch wraps the request, the parser_classes are wrapped in as well:

def initialize_request(self, request, *args, **kwargs):
    """
    Returns the initial request object.
    """
    parser_context = self.get_parser_context(request)

    return Request(
        request,
        # Parsers
        parsers=self.get_parsers(),
        # Authenticators
        authenticators=self.get_authenticators(),
        negotiator=self.get_content_negotiator(),
        parser_context=parser_context
    )

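Django REST Framework global settings

A question: with so many settings, where do the defaults come from for the ones you never learned about, or did not even know existed? Do they all have to be configured? Not at all: DRF ships defaults for them. They are exposed through this import: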

from rest_framework.settings import api_settings

Default settings:

DEFAULTS = {
    # Base API policies
    'DEFAULT_RENDERER_CLASSES': [
        'rest_framework.renderers.JSONRenderer',
        'rest_framework.renderers.BrowsableAPIRenderer',
    ],
    'DEFAULT_PARSER_CLASSES': [
        'rest_framework.parsers.JSONParser',
        'rest_framework.parsers.FormParser',
        'rest_framework.parsers.MultiPartParser'
    ],
    'DEFAULT_AUTHENTICATION_CLASSES': [
        'rest_framework.authentication.SessionAuthentication',
        'rest_framework.authentication.BasicAuthentication'
    ],
    'DEFAULT_PERMISSION_CLASSES': [
        'rest_framework.permissions.AllowAny',
    ],
    'DEFAULT_THROTTLE_CLASSES': [],
    'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation',
    'DEFAULT_METADATA_CLASS': 'rest_framework.metadata.SimpleMetadata',
    'DEFAULT_VERSIONING_CLASS': None,

    # Generic view behavior
    'DEFAULT_PAGINATION_CLASS': None,
    'DEFAULT_FILTER_BACKENDS': [],

    # Schema
    'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema',

    # Throttling
    'DEFAULT_THROTTLE_RATES': {
        'user': None,
        'anon': None,
    },
    'NUM_PROXIES': None,

    # Pagination
    'PAGE_SIZE': None,

    # Filtering
    'SEARCH_PARAM': 'search',
    'ORDERING_PARAM': 'ordering',

    # Versioning
    'DEFAULT_VERSION': None,
    'ALLOWED_VERSIONS': None,
    'VERSION_PARAM': 'version',

    # Authentication
    'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
    'UNAUTHENTICATED_TOKEN': None,

    # View configuration
    'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name',
    'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description',

    # Exception handling
    'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler',
    'NON_FIELD_ERRORS_KEY': 'non_field_errors',

    # Testing
    'TEST_REQUEST_RENDERER_CLASSES': [
        'rest_framework.renderers.MultiPartRenderer',
        'rest_framework.renderers.JSONRenderer'
    ],
    'TEST_REQUEST_DEFAULT_FORMAT': 'multipart',

    # Hyperlink settings
    'URL_FORMAT_OVERRIDE': 'format',
    'FORMAT_SUFFIX_KWARG': 'format',
    'URL_FIELD_NAME': 'url',

    # Input and output formats
    'DATE_FORMAT': ISO_8601,
    'DATE_INPUT_FORMATS': [ISO_8601],

    'DATETIME_FORMAT': ISO_8601,
    'DATETIME_INPUT_FORMATS': [ISO_8601],

    'TIME_FORMAT': ISO_8601,
    'TIME_INPUT_FORMATS': [ISO_8601],

    # Encoding
    'UNICODE_JSON': True,
    'COMPACT_JSON': True,
    'STRICT_JSON': True,
    'COERCE_DECIMAL_TO_STRING': True,
    'UPLOADED_FILES_USE_URL': True,

    # Browseable API
    'HTML_SELECT_CUTOFF': 1000,
    'HTML_SELECT_CUTOFF_TEXT': "More than {count} items...",

    # Schemas
    'SCHEMA_COERCE_PATH_PK': True,
    'SCHEMA_COERCE_METHOD_NAMES': {
        'retrieve': 'read',
        'destroy': 'delete'
    },
}
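
The fallback from project settings to these defaults can be pictured as a simple attribute lookup. The class below is a minimal sketch of that idea; SketchSettings and the trimmed DEFAULTS dictionary are assumptions for illustration and are much simpler than the real api_settings object.

```python
# Trimmed-down copy of a few defaults from the table above.
DEFAULTS = {
    "DEFAULT_VERSION": None,
    "VERSION_PARAM": "version",
    "DEFAULT_THROTTLE_RATES": {"user": None, "anon": None},
}


class SketchSettings:
    """Look a setting up in the user's dict first, then in the defaults."""

    def __init__(self, user_settings, defaults):
        self.user_settings = user_settings
        self.defaults = defaults

    def __getattr__(self, name):
        # Only called for names that are not regular attributes.
        if name not in self.defaults:
            raise AttributeError("Invalid API setting: %r" % name)
        return self.user_settings.get(name, self.defaults[name])


# What a project might put in settings.py (values taken from earlier examples).
REST_FRAMEWORK = {
    "DEFAULT_VERSION": "v1",
    "DEFAULT_THROTTLE_RATES": {"xxoo": "3/m"},
}

settings = SketchSettings(REST_FRAMEWORK, DEFAULTS)
print(settings.DEFAULT_VERSION)         # v1       (overridden)
print(settings.VERSION_PARAM)           # version  (default)
print(settings.DEFAULT_THROTTLE_RATES)  # {'xxoo': '3/m'}
```

In this sketch an overridden key replaces the default value as a whole: once DEFAULT_THROTTLE_RATES is set to {"xxoo": "3/m"}, the 'user' and 'anon' entries are no longer present, so every scope in use needs its own entry.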

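Summary

Throttling: rarely used in day-to-day work; for heavy concurrency the usual approach is to optimize rather than cut traffic off
Versioning: there is little need for DRF's versioning; writing the version into the URL yourself is far more readable than the built-in classes
Parsers: the defaults are enough; they are rarely configured deliberately in practice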

References

DRF official site: https://www.django-rest-framework.org/api-guide/settings/
DRF documentation in Chinese: https://q1mi.github.io/Django-REST-framework-documentation/
