DRF学习——限流组件

一、简介

本篇文章主要介绍drf限流组件的快速使用,并从源码角度分析drf限流组件的调用过程

二、快速使用

①在ext目录下定义throttle.py,其代码为:

# ext/throttle.py
from rest_framework.throttling import SimpleRateThrottle
from django.core.cache import cache as default_cache


class MyThrottle(SimpleRateThrottle):
    cache = default_cache
    scope = "xxx"
    # 每分钟最多只能访问5次
    THROTTLE_RATES = {"xxx": "5/m"}

    def get_cache_key(self, request, view):
        if request.user:
            ident = request.user.pk  # 用户ID
        else:
            # 获取请求用户IP(去request中找请求头)
            ident = self.get_ident(request)  

        # throttle_u # throttle_user_11.11.11.11ser_2

        return self.cache_format % {'scope': self.scope, 'ident': ident}

 ②在view视图中应用该组件即可,如:

# api/views.py
class LoginView(APIView):
    authentication_classes = []
    permission_classes = []
    # 应用自定义的限流组件
    throttle_classes = [MyThrottle, ]

    def post(self, request):
        # 1.接受用户名和密码
        print(request.data)
        user = request.data.get('username')
        pwd = request.data.get('password')

        # 2.数据库校验
        user_object = models.UserInfo.objects.filter(username=user, password=pwd).first()
        if not user_object:
            return Response({'code': 201, "msg": "用户名密码错误"})
        token = str(uuid.uuid4())
        user_object.token = token
        user_object.save()
        return Response({'code': 0, 'data': token})

③测试接口

访问login接口(前五次),可以正常返回 

访问login接口(第六次),收到频繁访问报错,成功限流

三、源码分析

①首先找到APIView类中的dispatch()方法,其源码如下:

    def dispatch(self, request, *args, **kwargs):
        """
        `.dispatch()` is pretty much the same as Django's regular dispatch,
        but with extra hooks for startup, finalize, and exception handling.
        """
        self.args = args
        self.kwargs = kwargs
        request = self.initialize_request(request, *args, **kwargs)
        self.request = request
        self.headers = self.default_response_headers  # deprecate?

        try:
            self.initial(request, *args, **kwargs)

            # Get the appropriate handler method
            if request.method.lower() in self.http_method_names:
                handler = getattr(self, request.method.lower(),
                                  self.http_method_not_allowed)
            else:
                handler = self.http_method_not_allowed

            response = handler(request, *args, **kwargs)

        except Exception as exc:
            response = self.handle_exception(exc)

        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response

 在request封装完成之后,走到对象初始化

self.initial(request, *args, **kwargs)

 其源码如下:

    def initial(self, request, *args, **kwargs):
        """
        Runs anything that needs to occur prior to calling the method handler.
        """
        self.format_kwarg = self.get_format_suffix(**kwargs)

        # Perform content negotiation and store the accepted info on the request
        neg = self.perform_content_negotiation(request)
        request.accepted_renderer, request.accepted_media_type = neg

        # Determine the API version, if versioning is in use.
        version, scheme = self.determine_version(request, *args, **kwargs)
        request.version, request.versioning_scheme = version, scheme

        # Ensure that the incoming request is permitted
        self.perform_authentication(request)
        self.check_permissions(request)
        self.check_throttles(request)

前两者分别为认证组件初始化和权限组件初始化,第三个check_throttles便是限流组件初始化,其源码如下:

    def check_throttles(self, request):
        """
        Check if request should be throttled.
        Raises an appropriate exception if the request is throttled.
        """
        throttle_durations = []
        for throttle in self.get_throttles():
            if not throttle.allow_request(request, self):
                throttle_durations.append(throttle.wait())

        if throttle_durations:
            # Filter out `None` values which may happen in case of config / rate
            # changes, see #1438
            durations = [
                duration for duration in throttle_durations
                if duration is not None
            ]

            duration = max(durations, default=None)
            self.throttled(request, duration)

该段代码首先 定义了一个空列表throttle_durations ,如何调用了get_throttles方法,其源码为:

    def get_throttles(self):
        """
        Instantiates and returns the list of throttles that this view uses.
        """
        return [throttle() for throttle in self.throttle_classes]

根据throttle_classes返回throttle实例化对象,而throttle_classes定义在views视图中,代码如下:

class LoginView(APIView):
    authentication_classes = []
    permission_classes = []
    # throttle_classes定义在此
    throttle_classes = [MyThrottle, ]

    def post(self, request):
        # 1.接受用户名和密码
        print(request.data)
        user = request.data.get('username')
        pwd = request.data.get('password')

        # 2.数据库校验
        user_object = models.UserInfo.objects.filter(username=user, password=pwd).first()
        if not user_object:
            return Response({'code': 201, "msg": "用户名密码错误"})
        token = str(uuid.uuid4())
        user_object.token = token
        user_object.save()
        return Response({'code': 0, 'data': token})

因此这里会返回MyThrottle的实例化对象,我们知道实例化对象的时候,便会去执行其中的__init__方法,因此找到MyThrottle所继承的SimpleRateThrottle类,其源码如下:

class SimpleRateThrottle(BaseThrottle):
    cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

由其源码我们可以知道它的__init__方法先是调用了get_rate方法,其源码如下:

    def get_rate(self):
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

首先是判断了实例化对象中是否有scope,没有则报错,然后返回实例化对象中的THROTTLE_RATES字典中的叫"self.scope"的value值,这里的THROTTLE_RATES在SimpleRateThrottle中的定义如下:

THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

告诉我们应该去settings.py中寻找DEFAULT_THROTTLE_RATES,但是由于前面我们在MyThrottle中已经定义了DEFAULT_THROTTLE_RATES,根据面向对象知识,应该是先读取MyThrottle中的DEFAULT_THROTTLE_RATES,这里需要说明一下,根据大部分开发者的编程习惯,这里最好是在settings.py中定义,因此在此需要修改如下代码:

# ext/throttle.py
from rest_framework.throttling import SimpleRateThrottle
from django.core.cache import cache as default_cache


class MyThrottle(SimpleRateThrottle):
    cache = default_cache
    scope = "xxx"
    # 将这里定义的THROTTLE_RATES注释掉
    # THROTTLE_RATES = {"xxx": "5/m"}

    def get_cache_key(self, request, view):
        if request.user:
            ident = request.user.pk  # 用户ID
        else:
            # 获取请求用户IP(去request中找请求头)
            ident = self.get_ident(request)  

        # throttle_u # throttle_user_11.11.11.11ser_2

        return self.cache_format % {'scope': self.scope, 'ident': ident}
# day13/settings.py
REST_FRAMEWORK = {
    'UNAUTHENTICATED_USER': None,
    'DEFAULT_AUTHENTICATION_CLASSES': [
        'ext.auth.QueryParamsAuthentication',
        'ext.auth.HeaderAuthentication',
        'ext.auth.NoAuthentication',
    ],
    'DEFAULT_PERMISSION_CLASSES': [
        'ext.per.UserPermission',
        'ext.per.ManagerPermission',
        'ext.per.AcePermission',
    ],
    # 在此定义DEFAULT_THROTTLE_RATES
    'DEFAULT_THROTTLE_RATES': {
        'xxx': "5/m"
    },
}

回到上面,现在弄清楚了DEFAULT_THROTTLE_RATES,此时需要返回 "self.scope"的value值,那么self.scope是什么?其定义在MyThrottle中,其值如下:

scope = "xxx"

 因此就是返回DEFAULT_THROTTLE_RATES该字典中key值为"xxx"的value值,即"5/m",因此回到前面的__init__方法,self.rate的值便是"5/m",下一步执行的代码如下:

self.num_requests, self.duration = self.parse_rate(self.rate)

 将self.rate作为参数传入到parse_rate方法,其源码如下:

    def parse_rate(self, rate):
        """
        Given the request rate string, return a two tuple of:
        <allowed number of requests>, <period of time in seconds>
        """
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)

这里首先是将传入的self.rate即"5/m"分隔,将"5"赋值给num并转化为int,将"m"赋值给period,然后同样的手法,获取duration字典的key为"m"的value值,即60,即1m等于60s

然后返回num_requests和duration,此时值分别为5和60

这里的实例化对象的步骤已经走完,我们回到check_throttles,回顾一下源码:

    def check_throttles(self, request):
        """
        Check if request should be throttled.
        Raises an appropriate exception if the request is throttled.
        """
        throttle_durations = []
        for throttle in self.get_throttles():
            if not throttle.allow_request(request, self):
                throttle_durations.append(throttle.wait())

        if throttle_durations:
            # Filter out `None` values which may happen in case of config / rate
            # changes, see #1438
            durations = [
                duration for duration in throttle_durations
                if duration is not None
            ]

            duration = max(durations, default=None)
            self.throttled(request, duration)

现在代码运行到了for throttle in self.get_throttles(),此时的throttle是当前自定义的限流类,接下来是调用throttle中的allow_request方法,判定它的返回值是否为False,如果为False则执行if语句中的代码,allow_request定义在throttle所继承的SimpleRateThrottle类中,其源码如下:

    def allow_request(self, request, view):
        """
        Implement the check to see if the request should be throttled.

        On success calls `throttle_success`.
        On failure calls `throttle_failure`.
        """
        if self.rate is None:
            return True

        self.key = self.get_cache_key(request, view)
        if self.key is None:
            return True

        self.history = self.cache.get(self.key, [])
        self.now = self.timer()

        # Drop any requests from the history which have now passed the
        # throttle duration
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()

首先判断self.rate是否为空,前面提到self.rate等于"5/m",因此不会返回True,接着便是调用get_cache_key方法,这是定义在throttle中的,可以简单理解成为每一个自定义限流组件生成一个位的key值,然后是调用self.cache即redis根据前面生成的唯一key去找到历史的访问记录,然后将现在的时间赋值给self.now

下面的while语句写的是当历史记录的最后一个时间小于当前时间减去duration即60时,便把这个历史记录的最小时间从历史记录中剔除

然后计算当前历史记录总存储的时间个数是否大于num_requests即5,如果大于就返回报错,否则就调用throttle_success方法,其源码如下:

    def throttle_success(self):
        """
        Inserts the current request's timestamp along with the key
        into the cache.
        """
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True

就是将现在的时间写入history中并且也写入redis存储,已便下次调用数据读取

以上,便是drf限流组件的源码分析

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Hemameba

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值