django rest framework系列07-基于IP和用户实现自定义访问频率限制以及源码流程

1、已知rest framework中APIview继承与django的View,且url规则中使用as_view(),方法先执行的是self.dispatch()

2、前面的用户认证,权限认证都是在执行完request封装后执行的self.initial(request, *args, **kwargs),那么可以先猜测下访问频率是不是self.check_throttles(request)

如果不是也没关系,我们重写他也是可以让他是的嘛!

    def initial(self, request, *args, **kwargs):
        """
        Runs anything that needs to occur prior to calling the method handler.
        """
        self.format_kwarg = self.get_format_suffix(**kwargs)

        # Perform content negotiation and store the accepted info on the request
        neg = self.perform_content_negotiation(request)
        request.accepted_renderer, request.accepted_media_type = neg

        # Determine the API version, if versioning is in use.
        version, scheme = self.determine_version(request, *args, **kwargs)
        request.version, request.versioning_scheme = version, scheme

        # Ensure that the incoming request is permitted
        self.perform_authentication(request) #用户登录认证
        self.check_permissions(request)#权限认证
        self.check_throttles(request)

3、进入check_throttles方法内容看看做了什么,

    def check_throttles(self, request):
        """
        Check if request should be throttled.
        Raises an appropriate exception if the request is throttled.
        """
        throttle_durations = []
        for throttle in self.get_throttles():  #通过下面解释可以看出,这里是循环一个列表对象
            if not throttle.allow_request(request, self):#调用每个对象中的allow_request方法
                throttle_durations.append(throttle.wait())

结构跟之前的用户认证和权限一样:如果猜的没错也是去配置文件找,应该也是要给列表生成式:

def get_throttles(self):
    """
    Instantiates and returns the list of throttles that this view uses.
    """
    return [throttle() for throttle in self.throttle_classes]

 果然,一模一样,循环self.throttle_classes,看看指向哪里,

throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES #果然指向配置文件

那我们先自定义一个throttle_classes列表,写一个类让他有个allow_request方法,写之前也可以看看默认的代码是怎么写的。

鼠标按住ctrl点击allow_request方法,如图:

可以看到默认有三个类有这个方法,传入的参数,可以看到有request,view

进入看看结构:从名字中可以看出这是一个模板,且allow_request内部是没有写的就抛出了一个异常。先不管,看看其他方法,get_ident中有个remote_addr = request.META.get('REMOTE_ADDR'),百度了下是获取IP

那么先不管了,我们也这样获取IP。然后通过自己的方法先实现,后面再来看他的其他内置的方法。

class BaseThrottle:
    """
    Rate throttling of requests.
    """

    def allow_request(self, request, view):
        """
        Return `True` if the request should be allowed, `False` otherwise.
        """
        raise NotImplementedError('.allow_request() must be overridden')

    def get_ident(self, request):
        """
        Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
        if present and number of proxies is > 0. If not use all of
        HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
        """
        xff = request.META.get('HTTP_X_FORWARDED_FOR')
        remote_addr = request.META.get('REMOTE_ADDR')
        num_proxies = api_settings.NUM_PROXIES

        if num_proxies is not None:
            if num_proxies == 0 or xff is None:
                return remote_addr
            addrs = xff.split(',')
            client_addr = addrs[-min(num_proxies, len(addrs))]
            return client_addr.strip()

        return ''.join(xff.split()) if xff else remote_addr

    def wait(self):
        """
        Optionally, return a recommended number of seconds to wait before
        the next request.
        """
        return None

回到视图中开始自定义访问频率类。,这里直接return False 看看,但是报错提示我们少一个wait方法,那么我们写一个wait方法什么也不做。然后视图中使用起来。

class myThrottle(object):

    def allow_request(self,request,view):

        return False
    
    def wait(self):
        pass

视图中使用:

class authView(APIView):
    '''
    用于用户登录
    '''
    authentication_classes = []
    throttle_classes = [myThrottle,]
    def post(self,request,*args,**kwargs):

        ret = {'code':1000,'msg':None}
        try:
            user = request._request.POST.get('username')  #从request获取用户提交为什么是这样看下面解释
            pwd = request._request.POST.get('password')
            obj = models.UserInfo.object.filter(username=user,password=pwd).first() #数据库中查找用户
            print(obj)
            if not obj:
                #判断obj是否空如果空则没有查到用户名密码错误!
                ret['code'] = 1001
                ret['msg'] = '用户名密码错误!'
                return JsonResponse(ret)
            #登录成功,生成token
            token = md5(user)
            #保存token 到数据库  updata_or_create方法存在则更新,不存在则创建。
            #user字段等于obj刚才查到的用户,token等于我们通过user+时间戳生成的token
            models.UserToken.object.update_or_create(user=obj,defaults={'token':token})
            ret['token'] = token
            ret['msg'] = '登录成功'
        except Exception as e:
            ret['code'] = 1002
            ret['msg'] = '请求异常'+ str(e)

        return JsonResponse(ret)

运行截图:可以看到是能用了,然后我们再去allow_request方法内实现相应的逻辑。

 allow_request开始实现自定义逻辑:

#访问记录全局变量
Access_records = {}
class myThrottle(object):
    '设置10秒访问3次'
    def allow_request(self,request,view):
        ip = request.META.get('REMOTE_ADDR')
        if ip not in Access_records:
            #如果全局变量中没有IP访问记录则新增一个key为IP,的时间戳列表
            Access_records[ip] = [time.time(),]
            return True
        else:
            #如果不是此IP不是第一次访问,则去全局变量中获取访问记录
            history = Access_records.get(ip)
            print(history)
            print(time.time() - history[-1])
            if history and time.time() - history[-1] > 10:
                #如果记录存在且最有一个记录减去当前时间小于60那么除去最后一条记录
                history.pop()
            if len(history) < 3:
                history.insert(0,time.time())
                return True
        return False

    def wait(self):
        pass

以上代码有参考内置函数,发现一个问题,判断当前时间减去最久那条记录时间,pop()出一个记录,问题是列表中又有三条记录!。那么下一次访问又得判断当前的-2。这里我改了下:

#访问记录全局变量
Access_records = {}
class myThrottle(object):
    '设置10秒访问3次'
    def allow_request(self,request,view):
        ip = request.META.get('REMOTE_ADDR')
        if ip not in Access_records:
            #如果全局变量中没有IP访问记录则新增一个key为IP,的时间戳列表
            Access_records[ip] = [time.time(),]
            return True
        else:
            #如果不是此IP不是第一次访问,则去全局变量中获取访问记录
            history = Access_records.get(ip)
            print(history)
            print(time.time() - history[-1])
            if history and time.time() - history[-1] > 10:
                #如果记录存在且最有一个记录减去当前时间小于60那么除去最后一条记录
                history.clear()
            if len(history) < 3:
                history.insert(0,time.time())
                return True
        return False

    def wait(self):
        pass

现在回到wait函数问题,刚开始没写会报错,可以在源码中看到,返回False时会调用了这个方法:

def check_throttles(self, request):
    """
    Check if request should be throttled.
    Raises an appropriate exception if the request is throttled.
    """
    throttle_durations = []
    for throttle in self.get_throttles():
        if not throttle.allow_request(request, self):
            throttle_durations.append(throttle.wait())  #这里返回False会调用wait方法,返回值加入到throttle_durations列表中

    if throttle_durations: 
        # Filter out `None` values which may happen in case of config / rate
        # changes, see #1438
        durations = [
            duration for duration in throttle_durations
            if duration is not None
        ]

        duration = max(durations, default=None)
        self.throttled(request, duration) #不管怎么样最后调用了一下这个方法。可以先进去看看,抛出异常而已。

进去wait之前看下,返回值加入到了一个列表,。最后取了个最大值丢给了self.throttled方法。我进去看了下就是抛出一个异常,参数是之前传进来的数值。

先不管先进入wait看看做了什么,BaseThrottle没写,我们进入下面那个SimpleRateThrottle看看:

def wait(self):
    """
    Returns the recommended next request time in seconds.
    """
    if self.history: # 这里history我们刚才只是在自己的函数内定义了个变量,看来其他函数要用那么写成类全局的,为什么这么说是我们那个history,见下面代码分析。
        remaining_duration = self.duration - (self.now - self.history[-1])  #self.now跟进去看了下就是time.tiem()。说明就是那个时间戳。 现在的时间戳减去self.history[-1]。也就是还有多少秒
                                                                            #然后这个剩余多少秒又被self.duration减了,跟进去看看
    else:
        remaining_duration = self.duration

    available_requests = self.num_requests - len(self.history) + 1
    if available_requests <= 0:
        return None

    return remaining_duration / float(available_requests)  

self.duration减(当前时间减最早访问的时间),跟进去看看self.duration是什么:

def __init__(self):
    if not getattr(self, 'rate', None):
        self.rate = self.get_rate()
    self.num_requests, self.duration = self.parse_rate(self.rate)#调用了self.parse_rate返回的两个参数先进去看看,为什么先进去看看,因为这个参数我看了是去配置文件找了但是配置文件没写,所以需要根据代码自己定义。

进入后代码:

def parse_rate(self, rate):
    """
    Given the request rate string, return a two tuple of:
    <allowed number of requests>, <period of time in seconds>
    """
    if rate is None:
        return (None, None)
    num, period = rate.split('/')  #传入的参数用/分割
    num_requests = int(num)  #/后的参数被int说明是个数字
    duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]  #取duration[period[0]]那么这个period第一个字符串必须是s、m、h、d。否则没办法取呀。再看看对应的数字,不就是秒,分,时,天,对应的秒嘛!
    return (num_requests, duration) #所以这里返回一个数字,还有一个时间

 假设我们传入了3/m返回应该是: num_requests =3      duration = 60   。那么回到waith函数

def wait(self):
    """
    假设   当前时间-最早访问时间 = 第一次记录间隔时间
           价格50秒
    """
    if self.history:
        #此时remaining_duration   = 60 -50  就是还需要等待10秒
        remaining_duration = self.duration - (self.now - self.history[-1])
    else:
        remaining_duration = self.duration

    #num_requests 一直等于3  len(self.history) 能到这里来肯定是3   3-3+1  == 1
    available_requests = self.num_requests - len(self.history) + 1
    if available_requests <= 0:
        return None

    return remaining_duration / float(available_requests) #也就是说wait返回是我们剩余要等待的时间。且学习到SimpleRateThrottle类的流程

 修改wait的同时发现很多变量需要种类很多地方使用,所以修改后的(访问时间/次数)可以在初始化的时候修改。

#访问记录全局变量
Access_records = {}
class myThrottle(BaseThrottle):
    '设置10秒访问3次'
    def __init__(self):
        self.history =None
        self.now = time.time()
        self.duration =10
        self.num_requests = 3

    def allow_request(self,request,view):
        ip = request.META.get('REMOTE_ADDR')
        if ip not in Access_records:
            #如果全局变量中没有IP访问记录则新增一个key为IP,的时间戳列表
            Access_records[ip] = [self.now,]
            return True
        else:
            #如果不是此IP不是第一次访问,则去全局变量中获取访问记录
            self.history = Access_records.get(ip)
            if self.history and self.now - self.history[-1] > self.duration:
                #如果记录存在且最有一个记录减去当前时间小于60那么除去最后一条记录
                self.history.clear()
            if len(self.history) < self.num_requests:
                self.history.insert(0,self.now)
                return True
        return False

    def wait(self):
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        print(len(self.history))
        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        return remaining_duration / float(available_requests)

测试可以使用,现在做个抽离,全局配置。局部其他配置:跟前几篇一样就道理,我就直接下代码了

文件目录:API.utils.throttle.myThrottle

from rest_framework.throttling import BaseThrottle
import time

#访问记录全局变量
Access_records = {}
class myThrottle(BaseThrottle):
    '设置10秒访问3次'
    def __init__(self):
        self.history =None
        self.now = time.time()  #当前时间戳
        self.duration =10  #规定时间内
        self.num_requests = 3  #访问次数

    def allow_request(self,request,view):
        ip = request.META.get('REMOTE_ADDR')
        if ip not in Access_records:
            #如果全局变量中没有IP访问记录则新增一个key为IP,的时间戳列表
            Access_records[ip] = [self.now,]
            return True
        else:
            #如果不是此IP不是第一次访问,则去全局变量中获取访问记录
            self.history = Access_records.get(ip)
            if self.history and self.now - self.history[-1] > self.duration:
                #如果记录存在且最有一个记录减去当前时间小于60那么除去最后一条记录
                self.history.clear()
            if len(self.history) < self.num_requests:
                self.history.insert(0,self.now)
                return True
        return False

    def wait(self):
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        print(len(self.history))
        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        return remaining_duration / float(available_requests)


class myThrottle2(BaseThrottle):
    '设置20秒访问5次'
    def __init__(self):
        self.history =None
        self.now = time.time()  #当前时间戳
        self.duration =20  #规定时间内
        self.num_requests = 5  #访问次数

    def allow_request(self,request,view):
        ip = request.META.get('REMOTE_ADDR')
        if ip not in Access_records:
            #如果全局变量中没有IP访问记录则新增一个key为IP,的时间戳列表
            Access_records[ip] = [self.now,]
            return True
        else:
            #如果不是此IP不是第一次访问,则去全局变量中获取访问记录
            self.history = Access_records.get(ip)
            if self.history and self.now - self.history[-1] > self.duration:
                #如果记录存在且最有一个记录减去当前时间小于60那么除去最后一条记录
                self.history.clear()
            if len(self.history) < self.num_requests:
                self.history.insert(0,self.now)
                return True
        return False

    def wait(self):
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        print(len(self.history))
        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None

        return remaining_duration / float(available_requests)
settings.py文件配置全局访问频次:
REST_FRAMEWORK ={
    'DEFAULT_AUTHENTICATION_CLASSES':['API.utils.auth.Authtication',],
    'DEFAULT_THROTTLE_CLASSES':['API.utils.throttle.myThrottle',]
}

视图总局部使用myThrottle2

from rest_framework.views import APIView
from django.http import JsonResponse
from API import models
from API.utils.permission import mypermissions
from API.utils.permission import mypermissions1
from API.utils.throttle import myThrottle2

from API.utils.throttle import myThrottle2


class authView(APIView):
    '''
    用于用户登录
    '''
    authentication_classes = []
    throttle_classes = [myThrottle2,]
    def post(self,request,*args,**kwargs):
        #self.dispatch()
        ret = {'code':1000,'msg':None}
        try:
            user = request._request.POST.get('username')  #从request获取用户提交为什么是这样看下面解释
            pwd = request._request.POST.get('password')
            obj = models.UserInfo.object.filter(username=user,password=pwd).first() #数据库中查找用户
            print(obj)
            if not obj:
                #判断obj是否空如果空则没有查到用户名密码错误!
                ret['code'] = 1001
                ret['msg'] = '用户名密码错误!'
                return JsonResponse(ret)
            #登录成功,生成token
            token = md5(user)
            #保存token 到数据库  updata_or_create方法存在则更新,不存在则创建。
            #user字段等于obj刚才查到的用户,token等于我们通过user+时间戳生成的token
            models.UserToken.object.update_or_create(user=obj,defaults={'token':token})
            ret['token'] = token
            ret['msg'] = '登录成功'
        except Exception as e:
            ret['code'] = 1002
            ret['msg'] = '请求异常'+ str(e)

        return JsonResponse(ret)

 最后保存文件发现一个问题,我们不能用全局变量保存记录,否则重启了就没有了。借鉴下内置方法的办法:下一篇讲。

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值