django---SimpleRateThrottle请求频率源码分析(rest-framework)

我们在做服务器API接口时,肯定会考虑该接口访问控制,例如某IP请求频率限制、某注册用户限制、未注册用户限制(后面的2中用户限制,是在时间频率限制基础上完善的),这样会起到数据保护、服务器减压作用

今天我们就来从源码的角度,看下django-rest-framework是如何实现,通过IP进行有效的频率控制的,我还是罗列出如何配置使用SimpleRateThrottle类,通过一个demo形式,结合源码理解更加容易

示例代码

我们如果使用django-rest-framework自带的IP频率控制,首先需要在settings.py也就是django项目下的配置,至于如何加载的,后续的源码会一一展开

settings.py配置DEFAULT_THROTTLE_RATES


REST_FRAMEWORK={
"DEFAULT_THROTTLE_RATES":{
            "xxxx":"3/m",
        },
}

我们继承SimpleRateThrottle实现的子类

class VisitThrottlee(SimpleRateThrottle):
    scope = "xxxx"

    def get_cache_key(self, request, view):
        return self.get_ident(request)

视图views

from rest_framework.viewsets import  ModelViewSet
class BookViewSet(ModelViewSet):
    throttle_classes = [VisitThrottlee]
    queryset = Book.objects.all()
    serializer_class = BookSerializers

router路由配置

url(r"books/$", views.BookViewSet.as_view({"get": "list", "post": "create"}), name="book_list"),

测试如下:
这里写图片描述

源码分析

例子简单,那我们就结合例子来分析下源码,看看SimpleRateThrottle是如何实现IP频率访问控制的?
对于rest-frameworkAPI接口设计,始终离不开View及其子类APIViewGenericAPIViewModelViewSet、及其mixins模块,本例示例是通过ModelViewSet来实现的,前几篇笔记已经多次的分析了诸如这些类的关系、执行流程,今天就简单分析,重点来了解SimpleRateThrottle的IP频率的逻辑代码
这里写图片描述

当我们在url通过as_view配置restful请求时,该as_view方法会返回一个ViewSetMixin类方法view,当我们通过某URL发起restful请求时,诸如get\post\put\delete,这里我们就本例的http://127.0.0.1:8000/books/(get请求)来说明,会触发view方法,构造我们的BookViewSet对象,也就是我们编写的ModelViewSet的子类,然后遍历as_viewaction,
以我们的示例url举例,会遍历出list的属性或者方法,然后将其赋值给get请求,然后执行ViewSetMixin中的dispatch方法,由于在ViewSetMixin找寻不到,就去多重继承类APIView中找寻,然后rest-framework封装为自己Request,然后进行认证校验、权限校验、IP频率控制,验证通过反射找到get方法,然后执行get方法,最后封装response返回即可

BaseThrottle

以上就是逻辑,终于一顿铺垫,我们接下来终于要分析IP频率控制的逻辑部分,也就是SimpleRateThrottle,该类继承BaseThrottle

class SimpleRateThrottle(BaseThrottle)
class BaseThrottle(object):
    def allow_request(self, request, view):
        raise NotImplementedError('.allow_request() must be overridden')

    def get_ident(self, request):
        xff = request.META.get('HTTP_X_FORWARDED_FOR')
        remote_addr = request.META.get('REMOTE_ADDR')
        num_proxies = api_settings.NUM_PROXIES

        if num_proxies is not None:
            if num_proxies == 0 or xff is None:
                return remote_addr
            addrs = xff.split(',')
            client_addr = addrs[-min(num_proxies, len(addrs))]
            return client_addr.strip()

        return ''.join(xff.split()) if xff else remote_addr

    def wait(self):
        return None

BaseThrottle中通过get_ident获取请求的IP值,对REMOTE_ADDR进行了一些优化封装,allow_request是必须让子类实现的方法,该方法在APIView中会使用,当allow_request返回True,就代表可以验证通过,当返回False,就验证失败,在频率规则内控制其请求访问,
以下是APIView中的代码

    def check_throttles(self, request):
        for throttle in self.get_throttles():
            if not throttle.allow_request(request, self):
                self.throttled(request, throttle.wait())

BaseThrottle类中的wait方法也需要我们自定义重写,IP频率控制,一般是在多少时间段,允许其访问多少次,该方法作用是访问失败时,返回告诉IP客户,还剩多少时间可以允许再次访问API接口操作数据

SimpleRateThrottle

接下来我们看下BaseThrottle的子类SimpleRateThrottle,我先列出该类的方法,然后逐一分析

__init__
get_cache_key
get_rate
parse_rate
allow_request
throttle_success
throttle_failure
wait

我们先看在初始化__init__方法做了哪些操作?

        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

首先查看本类及其父子类,有没有rate属性,我们查询给出结果是None,继而会通过方法get_rate获取rate,以下是get_rate方法

 def get_rate(self):
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)
        try:
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

get_rate方法首先判断not getattr(self, 'scope', None)父子类有没有定义scope属性如果没有定义就会抛出异常,如果定义了就去执行return self.THROTTLE_RATES[self.scope]
我们开始在定义的类中定义了一个变量属性scope = "xxxx",所以self.scope即是"xxxx"那我们看下self.THROTTLE_RATES是什么数据类型?跟踪到如下代码

 THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

它在settings模块中,注意这个不是项目下的配置,而是rest-framework里面的settings模块是一个字典数据的一个DEFAULT_THROTTLE_RATES键对应的值,是一个字典格式,如下是格式

'DEFAULT_THROTTLE_RATES': {
        'user': None,
        'anon': None,
    }

我们在回头看下get_rate方法中的return self.THROTTLE_RATES[self.scope],貌似不是我们要的数据,该字典没有scope键,那我们从哪里去查找呢?

我们在看看SimpleRateThrottle里面的属性中的THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES


class SimpleRateThrottle(BaseThrottle):
    cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

当我们在throttling.py模块中,导入from rest_framework.settings import api_settings

api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS)

当我们在SimpleRateThrottle中进行api_settings.DEFAULT_THROTTLE_RATES,会去APISettings中查找是否有DEFAULT_THROTTLE_RATES该属性,我们翻阅结果发现没有该属性,就会调用__getattr__这是python标准库的用法,并将DEFAULT_THROTTLE_RATES作为参数传递进来,这里api_settings的构造方法会初始化接受2个参数,DEFAULTSIMPORT_STRINGS
__getattr__继续执行val = self.user_settings[attr],豁然开朗,拿到getattr(settings, 'REST_FRAMEWORK', {}),也就是项目配置下的

REST_FRAMEWORK={
  "DEFAULT_THROTTLE_RATES":{
            "xxxx":"3/m",
        },
}

然后通过__getattr__中的attr参数也就是DEFAULT_THROTTLE_RATES,通过方法val = self.user_settings[attr] 拿到"xxxx":"3/m"

一顿铺垫,继续回过头来看SimpleRateThrottle类中的get_rate方法,终于可以从return self.THROTTLE_RATES[self.scope]拿到"3/m"
然后继续看SimpleRateThrottle类中的__init__函数,通过self.parse_rate(self.rate),方法如下:

 def parse_rate(self, rate):
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)

逻辑简单,这里就不去介绍了,最后拿到一个元组数据(3, 60),意思就是60秒内可以访问3次!!!

接下来就看allow_request逻辑部分了

    def allow_request(self, request, view):
        if self.rate is None:
            return True

        self.key = self.get_cache_key(request, view)

        if self.key is None:
            return True
        self.history = self.cache.get(self.key, [])
        self.now = self.timer()
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()

通过self.rate is None如果没有配置,就说明不进行频率控制,通过self.get_cache_key(request, view),进行请求的IP缓存,get_cache_key是一个必须让子类实现的类,

 def get_cache_key(self, request, view):
        raise NotImplementedError('.get_cache_key() must be overridden')

我们在看下我们自定义的频率校验类

class VisitThrottlee(SimpleRateThrottle):
    scope = "xxxx"

    def get_cache_key(self, request, view):
        return self.get_ident(request)

get_cache_key会得到请求的IP,然后通过self.history = self.cache.get(self.key, [])拿到该IP的请求记录,它的数据格式如下:

[1523440775.529766, 1523440775.1320794, 1523440755.5490992]

所以keyIP,history是其访问记录(列表),key-history是一对键值对

接下来就是本片博客的最重要的地方了

  while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()

这就是如何实现在多少秒访问多少次的逻辑,本例的实现是60秒内访问3次
逻辑是当某请求的IP,没有访问记录,就直接self.throttle_success,当某IP有访问记录,并且访问记录列表中的最早的一次IP访问,<=时间(这个时间是当前时间-规定的多少秒),
我们的实例中式60秒3次,某IP访问记录最多是3条记录,就pop掉列表最早的那个访问记录,添加进来最新的访问记录,当某IP的反问记录,>= 规定的访问次数(某个时间段内)就抛出异常

接下来我们看下访问成功后做了什么操作?

 def throttle_success(self):
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True

访问成功后,就在某IP访问记录头部,添加访问记录,然后更新缓存记录
访问失败时候,做了什么处理?直接返回False了,也就是allow_request方法返回了False

    def throttle_failure(self):
        return False

这些方法哪里调用的呢?我们最开始已经提及了,我们在回到APIView类中

    def check_throttles(self, request):
        """
        Check if request should be throttled.
        Raises an appropriate exception if the request is throttled.
        """
        for throttle in self.get_throttles():
            if not throttle.allow_request(request, self):
                self.throttled(request, throttle.wait())

如果返回False,就执行throttled方法,注意这里把我们的wait返回值传递给该方法
我们看下该方法

 def wait(self):
        """
        Returns the recommended next request time in seconds.
        """
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:

            return None

        return remaining_duration / float(available_requests)

逻辑是当被限制IP频率访问后,会告知还有多少时间才可以正常访问,然后将返回值传递给throttled
接下来看下throttled方法

    def throttled(self, request, wait):
        raise exceptions.Throttled(wait)

抛出了一个异常

class Throttled(APIException):
    status_code = status.HTTP_429_TOO_MANY_REQUESTS
    default_detail = _('Request was throttled.')
    extra_detail_singular = 'Expected available in {wait} second.'
    extra_detail_plural = 'Expected available in {wait} seconds.'
    default_code = 'throttled'

    def __init__(self, wait=None, detail=None, code=None):
        if detail is None:
            detail = force_text(self.default_detail)
        if wait is not None:
            wait = math.ceil(wait)
            detail = ' '.join((
                detail,
                force_text(ungettext(self.extra_detail_singular.format(wait=wait),
                                     self.extra_detail_plural.format(wait=wait),
                                     wait))))
        self.wait = wait
        super(Throttled, self).__init__(detail, code)

类继承于APIException

class APIException(Exception):
    """
    Base class for REST framework exceptions.
    Subclasses should provide `.status_code` and `.default_detail` properties.
    """
    status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
    default_detail = _('A server error occurred.')
    default_code = 'error'

    def __init__(self, detail=None, code=None):
        if detail is None:
            detail = self.default_detail
        if code is None:
            code = self.default_code

        self.detail = _get_error_details(detail, code)

    def __str__(self):
        return six.text_type(self.detail)

    def get_codes(self):
        """
        Return only the code part of the error details.

        Eg. {"name": ["required"]}
        """
        return _get_codes(self.detail)

    def get_full_details(self):
        """
        Return both the message & code parts of the error details.

        Eg. {"name": [{"message": "This field is required.", "code": "required"}]}
        """
        return _get_full_details(self.detail)

最后看下封装的返回信息格式

{
  "detail": "Request was throttled. Expected available in 48 seconds."
}

自己方式实现

当然我们也可以使用如下的方式,实现简易的IP频率访问控制

import time
VISITED_RECORD={}

class VisitThrottle(BaseThrottle):
    def __init__(self):
        self.history=None

    def allow_request(self,request,view):
        visit_ip=self.get_ident(request)
        print(visit_ip)
        ctime=time.time()

        while VISITED_RECORD and VISITED_RECORD[-1] <= ctime - 60:
            self.history.pop()
            VISITED_RECORD[visit_ip].insert(0, ctime)

            return True

        return False
    def wait(self):
        import time
        ctime = time.time()
        return 60 - (ctime - self.history[-1])

或者如下的方式:

from rest_framework.throttling import BaseThrottle,SimpleRateThrottle

import time
VISITED_RECORD={}

class VisitThrottle(BaseThrottle):
    def __init__(self):
        self.history=None

    def allow_request(self,request,view):
        print("ident",self.get_ident(request))
        #visit_ip=request.META.get('REMOTE_ADDR')
        visit_ip=self.get_ident(request)
        print(visit_ip)
        ctime=time.time()

        #第一次访问请求
        if visit_ip not in VISITED_RECORD:
            VISITED_RECORD[visit_ip]=[ctime]
            return True
        # self.history:当前请求IP的记录列表
        self.history = VISITED_RECORD[visit_ip]
        print(self.history)

        # 第2,3次访问请求
        if len(VISITED_RECORD[visit_ip])<3:
            VISITED_RECORD[visit_ip].insert(0,ctime)
            return True

        if ctime-VISITED_RECORD[visit_ip][-1]>60:
            VISITED_RECORD[visit_ip].pop()
            VISITED_RECORD[visit_ip].insert(0,ctime)
            print("ok")
            return True

        return False

    def wait(self):
        import time
        ctime = time.time()
        return 60 - (ctime - self.history[-1])
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值