限流
可以对接口访问的频次进行限制,以减轻服务器压力。
自定义限流思路:
实现机制:自定义一个字典,key 为唯一标识,可以是 ip 地址、用户名、用户id 等,value 为一个列表,按从新到旧的顺序存储访问时间(最新的一次访问放在第 0 条),如限制十秒内访问三次:
{"ip": [16:20:28, 16:20:25, 16:20:20]}
思路:
如果用户没访问过就创建该用户的访问记录并记录时间
如果用户已经访问过,获取用户访问历史,并且判断列表中最后一条(也就是最早的一次访问)的时间是否小于当前时间减去10秒
如果小于,说明这条记录已经超出 10 秒的时间窗口,那么就把最后一条数据删除(循环判断,直到最后一条落在窗口内)
限制 10 秒内最多访问三次:判断访问历史的长度是否小于三
如果小于三说明可以访问,需要在第0条插入本次访问的时间;否则拒绝本次访问
自定义限流代码实现:
需要借助 drf 的基础限流类:BaseThrottle 并在里面重写 allow_request,这条要记清楚,不重写会报错,还有一个 wait 方法,返回动态的秒数。
VISIT_RECORD = {}  # Normally this would live in a shared cache, not process memory.


class MyThrottle(BaseThrottle):
    """Sliding-window throttle keyed by the client's REMOTE_ADDR.

    Each address maps to a list of request timestamps in ``VISIT_RECORD``,
    newest first. A request is allowed while fewer than ``num_requests``
    timestamps fall inside the last ``duration`` seconds.
    """

    duration = 10      # Window length in seconds.
    num_requests = 3   # Maximum requests allowed per window.

    def __init__(self):
        # History list of the most recently checked client; read by wait().
        self.history = None

    def allow_request(self, request, view):
        """Return True if the request may proceed, False if it is throttled."""
        remote_addr = request.META.get('REMOTE_ADDR')
        current_time = time.time()

        # Always store a *list* of timestamps. The first visit previously
        # stored a bare float, which broke `history[-1]` on the second visit.
        history = VISIT_RECORD.setdefault(remote_addr, [])
        self.history = history

        # The last element is the oldest visit; drop everything that has
        # already left the window.
        while history and history[-1] < current_time - self.duration:
            history.pop()

        if len(history) < self.num_requests:
            history.insert(0, current_time)
            return True
        # Explicit denial instead of falling off the end and returning None.
        return False

    def wait(self):
        """Return the seconds until the oldest recorded visit leaves the window.

        Returns None when no history is available.
        """
        if not self.history:
            return None
        current_time = time.time()
        return max(0.0, self.duration - (current_time - self.history[-1]))
上述是自己定义的限流逻辑,实际上 drf 内部已经实现了限流功能,只需要导入就能使用:
SimpleRateThrottle 源码解读:简单速率限制
class SimpleRateThrottle(BaseThrottle):
    """
    A simple cache implementation, that only requires `.get_cache_key()`
    to be overridden.
    The rate (requests / seconds) is set by a `rate` attribute on the View
    class. The attribute is a string of the form 'number_of_requests/period'.
    Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
    Previous request information used for throttling is stored in the cache.
    """
    cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        # Runs first: resolve the rate unless a `rate` attribute is already set.
        if not getattr(self, 'rate', None):
            # get_rate() returns the value configured for `scope`, e.g. "3/m".
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

    # Subclasses must override get_cache_key.
    def get_cache_key(self, request, view):
        """
        Should return a unique cache-key which can be used for throttling.
        Must be overridden.
        May return `None` if the request should not be throttled.
        """
        raise NotImplementedError('.get_cache_key() must be overridden')

    def get_rate(self):
        """
        Determine the string representation of the allowed request rate.
        """
        # `scope` must be set, otherwise ImproperlyConfigured is raised here.
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)
        try:
            # THROTTLE_RATES (api_settings.DEFAULT_THROTTLE_RATES) must
            # contain a key equal to `scope`.
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def parse_rate(self, rate):
        """
        Given the request rate string, return a two tuple of:
        <allowed number of requests>, <period of time in seconds>
        """
        if rate is None:
            return (None, None)
        # Example input: "3/m"
        num, period = rate.split('/')
        num_requests = int(num)
        # Only the first letter of the period is looked up.
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        # Example result for "3/m": (3, 60)
        return (num_requests, duration)

    def allow_request(self, request, view):
        """
        Implement the check to see if the request should be throttled.
        On success calls `throttle_success`.
        On failure calls `throttle_failure`.
        """
        if self.rate is None:
            return True
        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True
        # `self.cache` is Django's default cache (`default_cache`).
        self.history = self.cache.get(self.key, [])
        # `self.timer` is `time.time`.
        self.now = self.timer()
        # Drop any requests from the history which have now passed the
        # throttle duration
        # While the last entry is at or before (now - duration) it has left
        # the window, so remove it.
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        # The history already holds num_requests entries: throttle (False).
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()

    def throttle_success(self):
        """
        Inserts the current request's timestamp along with the key
        into the cache.
        """
        # Put the newest timestamp at index 0 and write the list back to the cache.
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
        """
        Called when a request to the API has failed due to throttling.
        """
        return False

    def wait(self):
        """
        Returns the recommended next request time in seconds.
        """
        # Remaining time = duration - (now - timestamp of the last list entry).
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration
        # Spread the remaining time over the requests that are still available.
        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None
        return remaining_duration / float(available_requests)
自定义限流类:
class MySimpleRateThrottle(SimpleRateThrottle):
    """Throttle every client by IP using the rate configured for scope "xxoo"."""

    # Must match a key in REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"].
    scope = "xxoo"

    def get_cache_key(self, request, view):
        """Return the cache key under which this client's history is stored.

        The key is built with ``cache_format`` (scope + client ident) rather
        than the bare IP string. That keeps it from colliding with other
        cache entries and matches how the built-in throttles build keys.
        """
        return self.cache_format % {
            'scope': self.scope,
            'ident': self.get_ident(request),
        }
# The parent class SimpleRateThrottle already provides get_ident (shown
# below); it identifies the client by IP address.
def get_ident(self, request):
    """
    Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
    if present and number of proxies is > 0. If not use all of
    HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
    """
    xff = request.META.get('HTTP_X_FORWARDED_FOR')
    remote_addr = request.META.get('REMOTE_ADDR')
    num_proxies = api_settings.NUM_PROXIES
    if num_proxies is not None:
        if num_proxies == 0 or xff is None:
            return remote_addr
        addrs = xff.split(',')
        # Count num_proxies entries back from the right end of the list,
        # clamped to the number of addresses actually present.
        client_addr = addrs[-min(num_proxies, len(addrs))]
        return client_addr.strip()
    # NUM_PROXIES unset: use the whole header with whitespace removed,
    # or REMOTE_ADDR when the header is absent.
    return ''.join(xff.split()) if xff else remote_addr
配置文件:
REST_FRAMEWORK = {
    # Maps a throttle `scope` to its rate string ("<requests>/<period>").
    "DEFAULT_THROTTLE_RATES": {
        "xxoo": "3/m"
    },
}
APIView 中需要加上 throttle_classes:
class UserInfo(APIView):
    # Throttles applied to this view only.
    throttle_classes = [MySimpleRateThrottle]
效果:
内置限流类:
匿名限流类:针对未登录(匿名)用户的限流控制类
class AnonRateThrottle(SimpleRateThrottle):
    """
    Limits the rate of API calls that may be made by a anonymous users.
    The IP address of the request will be used as the unique cache key.
    """
    scope = 'anon'

    def get_cache_key(self, request, view):
        if request.user.is_authenticated:
            return None  # Only throttle unauthenticated requests.
        # Key = cache_format filled with this scope and the client ident.
        return self.cache_format % {
            'scope': self.scope,
            'ident': self.get_ident(request)
        }
认证用户限流类:针对登录(认证)用户的限流控制类
class UserRateThrottle(SimpleRateThrottle):
    """
    Limits the rate of API calls that may be made by a given user.
    The user id will be used as a unique cache key if the user is
    authenticated. For anonymous requests, the IP address of the request will
    be used.
    """
    scope = 'user'

    def get_cache_key(self, request, view):
        # Authenticated users are keyed by primary key, everyone else by
        # the value get_ident() returns.
        if request.user.is_authenticated:
            ident = request.user.pk
        else:
            ident = self.get_ident(request)
        return self.cache_format % {
            'scope': self.scope,
            'ident': ident
        }
范围限流类:按视图上配置的 throttle_scope 限流,对登录(认证)用户和匿名用户都生效
class ScopedRateThrottle(SimpleRateThrottle):
    """
    Limits the rate of API calls by different amounts for various parts of
    the API. Any view that has the `throttle_scope` property set will be
    throttled. The unique cache key will be generated by concatenating the
    user id of the request, and the scope of the view being accessed.
    """
    # Name of the view attribute that carries the scope.
    scope_attr = 'throttle_scope'

    def __init__(self):
        # Override the usual SimpleRateThrottle, because we can't determine
        # the rate until called by the view.
        pass

    def allow_request(self, request, view):
        # We can only determine the scope once we're called by the view.
        self.scope = getattr(view, self.scope_attr, None)
        # If a view does not have a `throttle_scope` always allow the request
        if not self.scope:
            return True
        # Determine the allowed request rate as we normally would during
        # the `__init__` call.
        self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)
        # We can now proceed as normal.
        return super().allow_request(request, view)

    def get_cache_key(self, request, view):
        """
        If `view.throttle_scope` is not set, don't apply this throttle.
        Otherwise generate the unique cache key by concatenating the user id
        with the '.throttle_scope` property of the view.
        """
        if request.user.is_authenticated:
            ident = request.user.pk
        else:
            ident = self.get_ident(request)
        return self.cache_format % {
            'scope': self.scope,
            'ident': ident
        }
配置文件:
REST_FRAMEWORK = {
    'DEFAULT_THROTTLE_CLASSES': (
        # Throttle for unauthenticated (anonymous) users
        'rest_framework.throttling.AnonRateThrottle',
        # Throttle for authenticated (logged-in) users
        'rest_framework.throttling.UserRateThrottle'
    ),
    # Rate for each throttle scope
    'DEFAULT_THROTTLE_RATES': {
        # Rate for authenticated users
        'user': '5/minute',
        # Rate for anonymous users
        'anon': '3/minute',
    },
}
针对范围的配置文件稍微不一样:
REST_FRAMEWORK = {
    # Throttle per scope
    'DEFAULT_THROTTLE_CLASSES': (
        'rest_framework.throttling.ScopedRateThrottle',
    ),
    # Rate for each scope name a view may select via `throttle_scope`
    'DEFAULT_THROTTLE_RATES': {
        'list': '3/m',
        'get': '5/m'
    },
}
视图中:
# ListView and DetailView select the same scope name ('list'), so the
# cache key ScopedRateThrottle builds for a client is identical for both.
class ListView(APIView):
    throttle_scope = 'list'
    ...


class DetailView(APIView):
    throttle_scope = 'list'
    ...


# GetView selects its own scope name ('get').
class GetView(APIView):
    throttle_scope = 'get'
    ...
版本
系统都有版本,这是迭代过程中的一种标记,用来记录演进过程。在常见的后端中一般会在 url 中加上版本,如:/api/v1/users,这种概念在 restful api 中被提出,个人理解是和系统迭代挂钩。
在 drf 中也有版本的概念,估计是为了兼容 restful api 这种概念吧,实际上个人感觉是不需要,在业务开发时一般都自己定义了,不过也看下 drf 中版本的玩法吧,drf 中分两种,一种是通过参数获取,另一种是配置在 url 中。
QueryParameterVersioning:参数传递
class QueryParameterVersioning(BaseVersioning):
    """
    GET /something/?version=0.1 HTTP/1.1
    Host: example.com
    Accept: application/json
    """
    # Author's note: the version is passed as a query parameter; the
    # parameter name, default and allowed versions come from settings.
    invalid_version_message = _('Invalid version in query parameter.')

    def determine_version(self, request, *args, **kwargs):
        # Read the parameter named by version_param, falling back to
        # default_version when it is absent.
        version = request.query_params.get(self.version_param, self.default_version)
        # Check against the allowed versions.
        if not self.is_allowed_version(version):
            # Not allowed: raise NotFound.
            raise exceptions.NotFound(self.invalid_version_message)
        return version
配置文件:
REST_FRAMEWORK = {
    "DEFAULT_VERSION": "v1",  # Default version
    "ALLOWED_VERSIONS": ["v1", "v2"],  # Allowed versions
    "VERSION_PARAM": "version",  # Name of the version parameter
}
视图中:
class UserInfo(APIView):
    # Note: assign a single class here, not a list.
    versioning_class = QueryParameterVersioning

    def get(self, request):
        # Option 1: read the raw query parameter from the wrapped Django request.
        raw_version = request._request.GET.get("version")
        print(raw_version)
        # Option 2: use request.version, which holds the version resolved
        # by the versioning class.
        resolved_version = request.version
        print(resolved_version)
        return Response()
效果:
URLPathVersioning:路径中(推荐使用)
class URLPathVersioning(BaseVersioning):
    """
    To the client this is the same style as `NamespaceVersioning`.
    The difference is in the backend - this implementation uses
    Django's URL keyword arguments to determine the version.
    An example URL conf for two views that accept two different versions.
    urlpatterns = [
        url(r'^(?P<version>[v1|v2]+)/users/$', users_list, name='users-list'),
        url(r'^(?P<version>[v1|v2]+)/users/(?P<pk>[0-9]+)/$', users_detail, name='users-detail')
    ]
    GET /1.0/something/ HTTP/1.1
    Host: example.com
    Accept: application/json
    """
    # NOTE(review): `[v1|v2]+` in the example above is a regex character
    # class, so it matches any run of the characters v, 1, | and 2;
    # `(?P<version>v1|v2)` is the strict alternation.
项目 urls.py 中:
from django.contrib import admin
from django.conf.urls import url, include
urlpatterns = [
    url('admin/', admin.site.urls),
    # Everything under api/ is delegated to the app's URLconf.
    # NOTE(review): `name` on an include() entry presumably has no effect,
    # and "xxoo" is also used as a view name later on - confirm and drop.
    url('api/', include("djcelerytest.urls"), name="xxoo"),
]
app urls.py 中:
from django.conf.urls import url
from djcelerytest.views import test, UserInfo
urlpatterns = [
    url('test/', test),
    # Capture the version as a named group. `(?P<version>v1|v2)` is a real
    # alternation; the previous `[v1|v2]+` was a character class that also
    # matched strings such as "vv", "12" or "|".
    url(r'^(?P<version>v1|v2)/info/', UserInfo.as_view()),
]
视图中:
class UserInfo(APIView):
    versioning_class = URLPathVersioning

    def get(self, request, *args, **kwargs):
        # The version resolved by the versioning class is exposed as
        # request.version.
        resolved_version = request.version
        print(resolved_version)
        return Response()
版本源码解读
initial 中在执行认证、权限之前会先获取版本号。
determine_version:
def determine_version(self, request, *args, **kwargs):
    """
    If versioning is being used, then determine any API version for the
    incoming request. Returns a two-tuple of (version, versioning_scheme)
    """
    if self.versioning_class is None:
        return (None, None)
    # versioning_class is either set on the view (e.g. URLPathVersioning)
    # or taken from the settings.
    scheme = self.versioning_class()
    # Delegate to the scheme's own determine_version().
    return (scheme.determine_version(request, *args, **kwargs), scheme)
determine_version :
def determine_version(self, request, *args, **kwargs):
    # version_param: "version" in the example settings
    # default_version: "v1" in the example settings
    version = kwargs.get(self.version_param, self.default_version)
    if version is None:
        version = self.default_version
    # Check against the allowed versions.
    if not self.is_allowed_version(version):
        raise exceptions.NotFound(self.invalid_version_message)
    # Return the version.
    return version
赋值给 request:
# Determine the API version, if versioning is in use.
# determine_version() returns the version and the scheme instance; the
# scheme is the object request.versioning_scheme.reverse() is called on.
version, scheme = self.determine_version(request, *args, **kwargs)
# Attach both to the request.
request.version, request.versioning_scheme = version, scheme
scheme:determine_version 返回的版本类实例,可以用它的 reverse 方法反向解析 URL
urls.py:
urlpatterns = [
    # `(?P<version>v1|v2)` is a real alternation; the previous `[v1|v2]+`
    # was a character class that also matched strings such as "vv" or "|".
    url(r'^(?P<version>v1|v2)/info/$', UserInfo.as_view(), name="xxoo"),
]
class UserInfo(APIView):
    # The route carries the version in the URL path, so the path-based
    # scheme is required: its reverse() supplies the `version` keyword
    # argument that the "xxoo" pattern needs. QueryParameterVersioning
    # does not pass that argument, so reversing this route would fail.
    versioning_class = URLPathVersioning

    def get(self, request, *args, **kwargs):
        # Version resolved from the URL path.
        print(request.version)
        # Absolute URL of the "xxoo" route for the current version.
        print(request.versioning_scheme.reverse(viewname="xxoo", request=request))
        return Response()

# Example output for GET /api/v1/info/:
# v1
# http://127.0.0.1:8000/api/v1/info/
其它版本类:
# Versioning based on URL namespaces
class NamespaceVersioning(BaseVersioning):
    """
    To the client this is the same style as `URLPathVersioning`.
    The difference is in the backend - this implementation uses
    Django's URL namespaces to determine the version.
    An example URL conf that is namespaced into two separate versions
    # users/urls.py
    urlpatterns = [
        url(r'^/users/$', users_list, name='users-list'),
        url(r'^/users/(?P<pk>[0-9]+)/$', users_detail, name='users-detail')
    ]
    """
# Versioning based on the subdomain
class HostNameVersioning(BaseVersioning):
    """
    GET /something/ HTTP/1.1
    Host: v1.example.com
    Accept: application/json
    """
# Versioning based on a request header
class AcceptHeaderVersioning(BaseVersioning):
    """
    GET /something/ HTTP/1.1
    Host: example.com
    Accept: application/json; version=1.0
    """
解析器
我们从 request.data 或是 request.query_params 中拿到的数据,实际都是 drf 帮我们做了一层解析:根据请求头 Content-Type 的不同,选用不同的解析器,解析出对应类型的数据。下面直接看源码:
# Content type used for file uploads
if self.content_type == 'multipart/form-data':
    if hasattr(self, '_body'):
        # Use already read data
        data = BytesIO(self._body)
    else:
        data = self
    try:
        self._post, self._files = self.parse_file_upload(self.META, data)
    except MultiPartParserError:
        # An error occurred while parsing POST data. Since when
        # formatting the error the request handler might access
        # self.POST, set self._post and self._file to prevent
        # attempts to parse POST data again.
        self._mark_post_parse_error()
        raise
# Only for application/x-www-form-urlencoded is POST filled from the body.
# Body format: name=xx&age=18
elif self.content_type == 'application/x-www-form-urlencoded':
    self._post, self._files = QueryDict(self.body, encoding=self._encoding), MultiValueDict()
else:
    # Any other content type: POST and FILES stay empty.
    self._post, self._files = QueryDict(encoding=self._encoding), MultiValueDict()
DRF 中的解析器:
JSONParser: 只支持Content-Type = application/json
class JSONParser(BaseParser):
    """
    Parses JSON-serialized data.
    """
    # Author's note: the parser is selected by media_type.
    media_type = 'application/json'
    renderer_class = renderers.JSONRenderer
    strict = api_settings.STRICT_JSON
如果是其他类型会提示:
{
"detail": "不支持请求中的媒体类型 “application/x-www-form-urlencoded”。"
}
FormParser: 只支持Content-Type = ‘application/x-www-form-urlencoded’
class FormParser(BaseParser):
    """
    Parser for form data.
    """
    # Media type this parser declares.
    media_type = 'application/x-www-form-urlencoded'
MultiPartParser: 只支持Content-Type = ‘multipart/form-data’ 文件上传
class MultiPartParser(BaseParser):
    """
    Parser for multipart form data, which may include file data.
    """
    # Media type this parser declares.
    media_type = 'multipart/form-data'
FileUploadParser: media_type = '*/*',什么类型都可以
class FileUploadParser(BaseParser):
    """
    Parser for file upload data.
    """
    # Wildcard media type.
    media_type = '*/*'
    # Error messages keyed by failure kind.
    errors = {
        'unhandled': 'FileUpload parse error - none of upload handlers can handle the stream',
        'no_filename': 'Missing filename. Request should include a Content-Disposition header with a filename parameter.',
    }
视图中需要指定解析器:
class UserInfo(APIView):
    # Parsers used for this view only.
    parser_classes = [JSONParser]
也可以全局配置:
REST_FRAMEWORK = {
    # Project-wide parser list, given as dotted import paths.
    "DEFAULT_PARSER_CLASSES": ["rest_framework.parsers.JSONParser"],
}
dispatch 在封装 request 时也会把 parser_classes 封装进去:
def initialize_request(self, request, *args, **kwargs):
    """
    Returns the initial request object.
    """
    parser_context = self.get_parser_context(request)
    return Request(
        request,
        # Parsers
        parsers=self.get_parsers(),
        # Authenticators
        authenticators=self.get_authenticators(),
        negotiator=self.get_content_negotiator(),
        parser_context=parser_context
    )
Django REST Framework 全局配置文件
疑问?那么多配置,如果自己没有学到,或者压根不知道,那么默认是从哪里来的?是不是都要配置?其实大可不必,DRF 默认也会给我们配置一些。从下面导入:
from rest_framework.settings import api_settings
默认配置:
# Built-in default value for every REST framework setting shown here.
DEFAULTS = {
    # Base API policies
    'DEFAULT_RENDERER_CLASSES': [
        'rest_framework.renderers.JSONRenderer',
        'rest_framework.renderers.BrowsableAPIRenderer',
    ],
    'DEFAULT_PARSER_CLASSES': [
        'rest_framework.parsers.JSONParser',
        'rest_framework.parsers.FormParser',
        'rest_framework.parsers.MultiPartParser'
    ],
    'DEFAULT_AUTHENTICATION_CLASSES': [
        'rest_framework.authentication.SessionAuthentication',
        'rest_framework.authentication.BasicAuthentication'
    ],
    'DEFAULT_PERMISSION_CLASSES': [
        'rest_framework.permissions.AllowAny',
    ],
    'DEFAULT_THROTTLE_CLASSES': [],
    'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation',
    'DEFAULT_METADATA_CLASS': 'rest_framework.metadata.SimpleMetadata',
    'DEFAULT_VERSIONING_CLASS': None,
    # Generic view behavior
    'DEFAULT_PAGINATION_CLASS': None,
    'DEFAULT_FILTER_BACKENDS': [],
    # Schema
    'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema',
    # Throttling
    'DEFAULT_THROTTLE_RATES': {
        'user': None,
        'anon': None,
    },
    'NUM_PROXIES': None,
    # Pagination
    'PAGE_SIZE': None,
    # Filtering
    'SEARCH_PARAM': 'search',
    'ORDERING_PARAM': 'ordering',
    # Versioning
    'DEFAULT_VERSION': None,
    'ALLOWED_VERSIONS': None,
    'VERSION_PARAM': 'version',
    # Authentication
    'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
    'UNAUTHENTICATED_TOKEN': None,
    # View configuration
    'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name',
    'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description',
    # Exception handling
    'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler',
    'NON_FIELD_ERRORS_KEY': 'non_field_errors',
    # Testing
    'TEST_REQUEST_RENDERER_CLASSES': [
        'rest_framework.renderers.MultiPartRenderer',
        'rest_framework.renderers.JSONRenderer'
    ],
    'TEST_REQUEST_DEFAULT_FORMAT': 'multipart',
    # Hyperlink settings
    'URL_FORMAT_OVERRIDE': 'format',
    'FORMAT_SUFFIX_KWARG': 'format',
    'URL_FIELD_NAME': 'url',
    # Input and output formats
    'DATE_FORMAT': ISO_8601,
    'DATE_INPUT_FORMATS': [ISO_8601],
    'DATETIME_FORMAT': ISO_8601,
    'DATETIME_INPUT_FORMATS': [ISO_8601],
    'TIME_FORMAT': ISO_8601,
    'TIME_INPUT_FORMATS': [ISO_8601],
    # Encoding
    'UNICODE_JSON': True,
    'COMPACT_JSON': True,
    'STRICT_JSON': True,
    'COERCE_DECIMAL_TO_STRING': True,
    'UPLOADED_FILES_USE_URL': True,
    # Browseable API
    'HTML_SELECT_CUTOFF': 1000,
    'HTML_SELECT_CUTOFF_TEXT': "More than {count} items...",
    # Schemas
    'SCHEMA_COERCE_PATH_PK': True,
    'SCHEMA_COERCE_METHOD_NAMES': {
        'retrieve': 'read',
        'destroy': 'delete'
    },
}
总结
限流:限流工作中一般用的很少,针对大并发一般是想办法优化,不会截流
版本:drf 中版本没必要用,url 中自己写比使用内置的可读性高很多了
解析器:使用默认的就行,一般工作中也不会刻意配置
参考文献
drf 官网:https://www.django-rest-framework.org/api-guide/settings/
drf 中文文档:https://q1mi.github.io/Django-REST-framework-documentation/