我们在做服务器API
接口时,肯定会考虑该接口访问控制,例如某IP
请求频率限制、某注册用户限制、未注册用户限制(后面的2中用户限制,是在时间频率限制基础上完善的),这样会起到数据保护、服务器减压作用
今天我们就来从源码的角度,看下django-rest-framework
是如何实现,通过IP
进行有效的频率控制的,我还是罗列出如何配置使用SimpleRateThrottle
类,通过一个demo形式,结合源码理解更加容易
示例代码
我们如果使用django-rest-framework
自带的IP
频率控制,首先需要在settings.py
也就是django项目下的配置,至于如何加载的,后续的源码会一一展开
settings.py配置DEFAULT_THROTTLE_RATES
REST_FRAMEWORK={
"DEFAULT_THROTTLE_RATES":{
"xxxx":"3/m",
},
}
我们继承SimpleRateThrottle
实现的子类
class VisitThrottlee(SimpleRateThrottle):
scope = "xxxx"
def get_cache_key(self, request, view):
return self.get_ident(request)
视图views
from rest_framework.viewsets import ModelViewSet
class BookViewSet(ModelViewSet):
throttle_classes = [VisitThrottlee]
queryset = Book.objects.all()
serializer_class = BookSerializers
router路由配置
url(r"books/$", views.BookViewSet.as_view({"get": "list", "post": "create"}), name="book_list"),
测试如下:
源码分析
例子简单,那我们就结合例子来分析下源码,看看SimpleRateThrottle
是如何实现IP
频率访问控制的?
对于rest-framework
的API
接口设计,始终离不开View
及其子类APIView
、GenericAPIView
、ModelViewSet
、及其mixins
模块,本例示例是通过ModelViewSet
来实现的,前几篇笔记已经多次的分析了诸如这些类的关系、执行流程,今天就简单分析,重点来了解SimpleRateThrottle
的IP频率的逻辑代码
当我们在url
通过as_view
配置restful
请求时,该as_view
方法会返回一个ViewSetMixin类方法view
,当我们通过某URL
发起restful
请求时,诸如get\post\put\delete
,这里我们就本例的http://127.0.0.1:8000/books/
(get请求)来说明,会触发view
方法,构造我们的BookViewSet
对象,也就是我们编写的ModelViewSet
的子类,然后遍历as_view
的action
,
以我们的示例url举例,会遍历出list
的属性或者方法,然后将其赋值给get
请求,然后执行ViewSetMixin
中的dispatch
方法,由于在ViewSetMixin
找寻不到,就去多重继承类APIView
中找寻,然后rest-framework
封装为自己Request
,然后进行认证校验、权限校验、IP频率控制,验证通过反射找到get
方法,然后执行get
方法,最后封装response
返回即可
BaseThrottle
以上就是逻辑,终于一顿铺垫,我们接下来终于要分析IP频率控制
的逻辑部分,也就是SimpleRateThrottle
,该类继承BaseThrottle
class SimpleRateThrottle(BaseThrottle)
class BaseThrottle(object):
def allow_request(self, request, view):
raise NotImplementedError('.allow_request() must be overridden')
def get_ident(self, request):
xff = request.META.get('HTTP_X_FORWARDED_FOR')
remote_addr = request.META.get('REMOTE_ADDR')
num_proxies = api_settings.NUM_PROXIES
if num_proxies is not None:
if num_proxies == 0 or xff is None:
return remote_addr
addrs = xff.split(',')
client_addr = addrs[-min(num_proxies, len(addrs))]
return client_addr.strip()
return ''.join(xff.split()) if xff else remote_addr
def wait(self):
return None
在BaseThrottle
中通过get_ident
获取请求的IP
值,对REMOTE_ADDR
进行了一些优化封装,allow_request
是必须让子类实现的方法,该方法在APIView
中会使用,当allow_request
返回True
,就代表可以验证通过,当返回False
,就验证失败,在频率规则内控制其请求访问,
以下是APIView
中的代码
def check_throttles(self, request):
for throttle in self.get_throttles():
if not throttle.allow_request(request, self):
self.throttled(request, throttle.wait())
BaseThrottle
类中的wait
方法也需要我们自定义重写,IP频率控制,一般是在多少时间段,允许其访问多少次,该方法作用是访问失败时,返回告诉IP客户,还剩多少时间可以允许再次访问API接口操作数据
SimpleRateThrottle
接下来我们看下BaseThrottle
的子类SimpleRateThrottle
,我先列出该类的方法,然后逐一分析
__init__
get_cache_key
get_rate
parse_rate
allow_request
throttle_success
throttle_failure
wait
我们先看在初始化__init__
方法做了哪些操作?
if not getattr(self, 'rate', None):
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)
首先查看本类及其父子类,有没有rate
属性,我们查询给出结果是None
,继而会通过方法get_rate
获取rate
,以下是get_rate
方法
def get_rate(self):
if not getattr(self, 'scope', None):
msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
self.__class__.__name__)
raise ImproperlyConfigured(msg)
try:
return self.THROTTLE_RATES[self.scope]
except KeyError:
msg = "No default throttle rate set for '%s' scope" % self.scope
raise ImproperlyConfigured(msg)
get_rate
方法首先判断not getattr(self, 'scope', None)
父子类有没有定义scope
属性如果没有定义就会抛出异常,如果定义了就去执行return self.THROTTLE_RATES[self.scope]
我们开始在定义的类中定义了一个变量属性scope = "xxxx"
,所以self.scope
即是"xxxx"
那我们看下self.THROTTLE_RATES是什么数据类型?跟踪到如下代码
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
它在settings
模块中,注意这个不是项目下的配置,而是rest-framework
里面的settings模块是一个字典数据的一个DEFAULT_THROTTLE_RATES
键对应的值,是一个字典格式,如下是格式
'DEFAULT_THROTTLE_RATES': {
'user': None,
'anon': None,
}
我们在回头看下get_rate
方法中的return self.THROTTLE_RATES[self.scope]
,貌似不是我们要的数据,该字典没有scope
键,那我们从哪里去查找呢?
我们在看看SimpleRateThrottle
里面的属性中的THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
class SimpleRateThrottle(BaseThrottle):
cache = default_cache
timer = time.time
cache_format = 'throttle_%(scope)s_%(ident)s'
scope = None
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
当我们在throttling.py
模块中,导入from rest_framework.settings import api_settings
api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS)
当我们在SimpleRateThrottle
中进行api_settings.DEFAULT_THROTTLE_RATES
,会去APISettings
中查找是否有DEFAULT_THROTTLE_RATES
该属性,我们翻阅结果发现没有该属性,就会调用__getattr__
这是python标准库的用法,并将DEFAULT_THROTTLE_RATES
作为参数传递进来,这里api_settings的构造方法会初始化接受2个参数,DEFAULTS
和IMPORT_STRINGS
,
在__getattr__
继续执行val = self.user_settings[attr]
,豁然开朗,拿到getattr(settings, 'REST_FRAMEWORK', {})
,也就是项目配置下的
REST_FRAMEWORK={
"DEFAULT_THROTTLE_RATES":{
"xxxx":"3/m",
},
}
然后通过__getattr__
中的attr
参数也就是DEFAULT_THROTTLE_RATES
,通过方法val = self.user_settings[attr]
拿到"xxxx":"3/m"
一顿铺垫,继续回过头来看SimpleRateThrottle
类中的get_rate
方法,终于可以从return self.THROTTLE_RATES[self.scope]
拿到"3/m"
,
然后继续看SimpleRateThrottle
类中的__init__
函数,通过self.parse_rate(self.rate)
,方法如下:
def parse_rate(self, rate):
if rate is None:
return (None, None)
num, period = rate.split('/')
num_requests = int(num)
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
return (num_requests, duration)
逻辑简单,这里就不去介绍了,最后拿到一个元组数据(3, 60),意思就是60秒内可以访问3次!!!
接下来就看allow_request
逻辑部分了
def allow_request(self, request, view):
if self.rate is None:
return True
self.key = self.get_cache_key(request, view)
if self.key is None:
return True
self.history = self.cache.get(self.key, [])
self.now = self.timer()
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
通过self.rate is None
如果没有配置,就说明不进行频率控制,通过self.get_cache_key(request, view)
,进行请求的IP
缓存,get_cache_key
是一个必须让子类实现的类,
def get_cache_key(self, request, view):
raise NotImplementedError('.get_cache_key() must be overridden')
我们在看下我们自定义的频率校验类
class VisitThrottlee(SimpleRateThrottle):
scope = "xxxx"
def get_cache_key(self, request, view):
return self.get_ident(request)
get_cache_key
会得到请求的IP
,然后通过self.history = self.cache.get(self.key, [])
拿到该IP的请求记录,它的数据格式如下:
[1523440775.529766, 1523440775.1320794, 1523440755.5490992]
所以key
是IP
,history
是其访问记录(列表),key-history
是一对键值对
接下来就是本片博客的最重要的地方了
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
这就是如何实现在多少秒访问多少次的逻辑,本例的实现是60秒内访问3次
逻辑是当某请求的IP
,没有访问记录,就直接self.throttle_success
,当某IP
有访问记录,并且访问记录列表中的最早的一次IP
访问,<=
时间(这个时间是当前时间-规定的多少秒),
我们的实例中式60秒3次,某IP
访问记录最多是3条记录,就pop
掉列表最早的那个访问记录,添加进来最新的访问记录,当某IP的反问记录,>=
规定的访问次数(某个时间段内)就抛出异常
接下来我们看下访问成功后做了什么操作?
def throttle_success(self):
self.history.insert(0, self.now)
self.cache.set(self.key, self.history, self.duration)
return True
访问成功后,就在某IP访问记录头部,添加访问记录,然后更新缓存记录
访问失败时候,做了什么处理?直接返回False
了,也就是allow_request
方法返回了False
def throttle_failure(self):
return False
这些方法哪里调用的呢?我们最开始已经提及了,我们在回到APIView
类中
def check_throttles(self, request):
"""
Check if request should be throttled.
Raises an appropriate exception if the request is throttled.
"""
for throttle in self.get_throttles():
if not throttle.allow_request(request, self):
self.throttled(request, throttle.wait())
如果返回False
,就执行throttled
方法,注意这里把我们的wait
返回值传递给该方法
我们看下该方法
def wait(self):
"""
Returns the recommended next request time in seconds.
"""
if self.history:
remaining_duration = self.duration - (self.now - self.history[-1])
else:
remaining_duration = self.duration
available_requests = self.num_requests - len(self.history) + 1
if available_requests <= 0:
return None
return remaining_duration / float(available_requests)
逻辑是当被限制IP频率访问后,会告知还有多少时间才可以正常访问,然后将返回值传递给throttled
接下来看下throttled
方法
def throttled(self, request, wait):
raise exceptions.Throttled(wait)
抛出了一个异常
class Throttled(APIException):
status_code = status.HTTP_429_TOO_MANY_REQUESTS
default_detail = _('Request was throttled.')
extra_detail_singular = 'Expected available in {wait} second.'
extra_detail_plural = 'Expected available in {wait} seconds.'
default_code = 'throttled'
def __init__(self, wait=None, detail=None, code=None):
if detail is None:
detail = force_text(self.default_detail)
if wait is not None:
wait = math.ceil(wait)
detail = ' '.join((
detail,
force_text(ungettext(self.extra_detail_singular.format(wait=wait),
self.extra_detail_plural.format(wait=wait),
wait))))
self.wait = wait
super(Throttled, self).__init__(detail, code)
类继承于APIException
class APIException(Exception):
"""
Base class for REST framework exceptions.
Subclasses should provide `.status_code` and `.default_detail` properties.
"""
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
default_detail = _('A server error occurred.')
default_code = 'error'
def __init__(self, detail=None, code=None):
if detail is None:
detail = self.default_detail
if code is None:
code = self.default_code
self.detail = _get_error_details(detail, code)
def __str__(self):
return six.text_type(self.detail)
def get_codes(self):
"""
Return only the code part of the error details.
Eg. {"name": ["required"]}
"""
return _get_codes(self.detail)
def get_full_details(self):
"""
Return both the message & code parts of the error details.
Eg. {"name": [{"message": "This field is required.", "code": "required"}]}
"""
return _get_full_details(self.detail)
最后看下封装的返回信息格式
{
"detail": "Request was throttled. Expected available in 48 seconds."
}
自己方式实现
当然我们也可以使用如下的方式,实现简易的IP频率访问控制
import time
VISITED_RECORD={}
class VisitThrottle(BaseThrottle):
def __init__(self):
self.history=None
def allow_request(self,request,view):
visit_ip=self.get_ident(request)
print(visit_ip)
ctime=time.time()
while VISITED_RECORD and VISITED_RECORD[-1] <= ctime - 60:
self.history.pop()
VISITED_RECORD[visit_ip].insert(0, ctime)
return True
return False
def wait(self):
import time
ctime = time.time()
return 60 - (ctime - self.history[-1])
或者如下的方式:
from rest_framework.throttling import BaseThrottle,SimpleRateThrottle
import time
VISITED_RECORD={}
class VisitThrottle(BaseThrottle):
def __init__(self):
self.history=None
def allow_request(self,request,view):
print("ident",self.get_ident(request))
#visit_ip=request.META.get('REMOTE_ADDR')
visit_ip=self.get_ident(request)
print(visit_ip)
ctime=time.time()
#第一次访问请求
if visit_ip not in VISITED_RECORD:
VISITED_RECORD[visit_ip]=[ctime]
return True
# self.history:当前请求IP的记录列表
self.history = VISITED_RECORD[visit_ip]
print(self.history)
# 第2,3次访问请求
if len(VISITED_RECORD[visit_ip])<3:
VISITED_RECORD[visit_ip].insert(0,ctime)
return True
if ctime-VISITED_RECORD[visit_ip][-1]>60:
VISITED_RECORD[visit_ip].pop()
VISITED_RECORD[visit_ip].insert(0,ctime)
print("ok")
return True
return False
def wait(self):
import time
ctime = time.time()
return 60 - (ctime - self.history[-1])