DRF-认证-权限-限流组件使用和源码精读

三大组件的使用和源码分析

认证组件

认证使用

from rest_framework.authentication import BaseAuthentication

# 编写认证类 继承BaseAuthentication
class MyAuthentication(BaseAuthentication):
    def authenticate(self, request):
        """
        做用户的认证
        1. 读取请求的token
        2. 校验合法性
        3. 返回值
            1. 返回元组 认证成功 返回2个值(request.user request.auth)
            2. 抛出异常 认证失败 -> 抛出错误信息
            3. 返回None 多个类认证 [类1,类2.....] 如果都返回为None 那就是匿名用户 
        """
        # 请求头获取token
        token = request.query_params.get("token")
        # 请求头获取token
        # token = request.META.get("HTTP_AUTHORIZATION")
        if  token:
            user_token=UserInfo.objects.filter(token=token).first()
            # 认证通过
            if user_token:
                return user_token.user,token
            else:
                raise AuthenticationFailed('认证失败')
        else:
            raise AuthenticationFailed('请求地址中需要携带token')

    # 解决 response是403的问题
    def authenticate_header(self, request):
        return "API"
 
# 使用配置 
# 1. 全局配置 在setting.py中配置
REST_FRAMEWORK = {
        "DEFAULT_AUTHENTICATION_CLASSES": [
            "ext.auth.MyAuthentication",
            # "ext.auth.QueryParamsAuthentication",
            # "ext.auth.HeaderAuthentication",
            # "ext.auth.NoAuthentication",
    ]
}
# 2. 局部使用,在视图类上写
class API_view(APIView):
    # 视图类中添加认证的类
    authentication_classes = [MyAuthentication, ]
    ......
# 3. 局部禁用
class API_view(APIView):
    # 视图类中添加认证的类
    authentication_classes = []
    ......

认证源码精读

承接上次 那我们之间看执行流程
APIView --> dispatch(self, request, *args, **kwargs) --> self.initial(request, *args, **kwargs)
-->(认证走这个方法) self.perform_authentication(request) --> (里面就一个 request.user 所以回Request对象中找user)request.user
class Request:
    @property
    def user(self):
        # 第1步. 判断是否有_user
        if not hasattr(self, '_user'):
            with wrap_attributeerrors():
                 # 第2步. 一开始我们没有_user所以我们看 self._authenticate()方法
                self._authenticate()
        return self._user
    # 第2步.
	def _authenticate(self):
        # 2.1 从self.authenticators这里列表中获取对象
        for authenticator in self.authenticators:
            try:
                # 执行对象的authenticator中的方法(这里就相当于执行我们创建的MyAuthentication中的authenticate方法) sellf就是Request对象
                user_auth_tuple = authenticator.authenticate(self)
            except exceptions.APIException:
                self._not_authenticated()
                raise
			# 2.2 判断user_auth_tuple是否为空
            if user_auth_tuple is not None:
                self._authenticator = authenticator
                # user_auth_tuple返回2个值 然后赋值
                self.user, self.auth = user_auth_tuple
                return
        # 最后我们看看这个 self._not_authenticated()方法干了什么
        self._not_authenticated()
        
    def _not_authenticated(self):
        self._authenticator = None
		# 读取配置中的认证对象 有就实例化 没有就赋值为None
        if api_settings.UNAUTHENTICATED_USER:
            self.user = api_settings.UNAUTHENTICATED_USER()
        else:
            self.user = None

        if api_settings.UNAUTHENTICATED_TOKEN:
            self.auth = api_settings.UNAUTHENTICATED_TOKEN()
        else:
            self.auth = None

总结一下认证组件干了什么:

  • 主要就是判断用户类型 通过认证 返回2个值 然后好进行下一步权限校验 失败就抛出异常
  • 认证组件时或的关系 通过一个就能通过

权限组件

权限使用

from rest_framework.permissions import BasePermission

# 写一个类,继承BasePermission,重写has_permission,如果权限通过,就返回True,不通过就返回False
class UserPermission(BasePermission):
    message = {"status": False, 'msg': "无权访问1"}

    def has_permission(self, request, view):
        # 如果该字段用了choice,通过get_字段名_display()就能取出choice后面的中文
        if request.user.role == 3:
            return True
        return False
    
# 使用配置 
# 1. 全局配置 在setting.py中配置
REST_FRAMEWORK = {
        "DEFAULT_AUTHENTICATION_CLASSES": [
            # "ext.auth.BossPermission",
            # "ext.auth.ManagerPermission",
            "ext.per.UserPermission",
    ]
}
# 2. 局部使用,在视图类上写
class API_view(APIView):
    # 视图类中添加认证的类
	permission_classes = [UserPermission,]
    ......
# 3. 局部禁用
class API_view(APIView):
    # 视图类中添加认证的类
    permission_classes = []
    ......

权限源码精读

APIView --> dispatch(self, request, *args, **kwargs) --> self.initial(request, *args, **kwargs)-->self.check_permissions(request)
class APIView(View):
    def check_permissions(self, request):
        # 遍历 get_permissions里的权限对象 
        for permission in self.get_permissions():
            # has_permission 返回 True和False
            if not permission.has_permission(request, self):
                # 没有通过 就用 permission_denied 抛出异常
                self.permission_denied(
                    request,
                    message=getattr(permission, 'message', None),
                    code=getattr(permission, 'code', None)
                )

权限组件的扩展

  • 权限组件是且关系,即A条件 且 B条件 且 C条件,同时满足。
  • 那么我们优化一下实现或关系
在自己的视图类中重写 check_permissions()方法
    def check_permissions(self, request):
        for permission in self.get_permissions():
            if permission.has_permission(request, self):
                return
		else:
              self.permission_denied(
                    request,
                    message=getattr(permission, 'message', None),
                    code=getattr(permission, 'code', None)
                )
# 如果嫌每个类中都要重新方法 那就直接创建一个类继承

限流组件

限流组件的使用
from rest_framework.throttling import SimpleRateThrottle
from rest_framework import exceptions
from rest_framework import status
class ThrottledException(exceptions.APIException):
    status_code = status.HTTP_429_TOO_MANY_REQUESTS
    default_code = 'throttled'

# 直接继承SimpleRateThrottle 重写get_cache_key方法就行
class MyThrottle(SimpleRateThrottle):
    # 就是可以理解成一个名称 避免重名概率和方便找到对应的限流类
    scope = "user"
    # 其他:'s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day' 不建议这种写法 写道全局配置会好一些
    # THROTTLE_RATES = {"user": "10/m"}

    def get_cache_key(self, request, view):
        if request.user:
            ident = request.user.pk  # 用户ID
        else:
            # 获取IP 如果是匿名代理其实获取不到真正的IP
            ident = self.get_ident(request)
        return self.cache_format % {'scope': self.scope, 'ident': ident}

    # 重写throttle_failure 自定义错误提示
    def throttle_failure(self):
        wait = self.wait()
        detail = {
            "code": 1005,
            "data": "访问频率限制",
            'detail': "需等待{}s才能访问".format(int(wait))
        }
        raise ThrottledException(detail)
# 使用配置 
# 1. 全局配置 在setting.py中配置
REST_FRAMEWORK = {
    	# 限流列表
        "DEFAULT_AUTHENTICATION_CLASSES": [
            "ext.throttle.MyThrottle",
            # 'ext.throttle.AnonRateThrottle',   # 匿名用户限流
            # 'ext.throttle.UserRateThrottle',   # 登录用户限流
            # 'ext.throttle.ScopedRateThrottle',  # 针对某一个接口限流(只能在APIView类使用)
    ],
    	# 限制访问频次
        'DEFAULT_THROTTLE_RATES': {
        'xxx': '1/m',
        'user': '10/m'
    }
}
# 2. 局部使用,在视图类上写
class API_view(APIView):
    # 视图类中添加认证的类
	throttle_classes = [MyThrottle, ]
    ......
# 3. 局部禁用
class API_view(APIView):
    # 视图类中添加认证的类
    throttle_classes = []
    ......
限流组件源码精读
APIView --> dispatch(self, request, *args, **kwargs) --> self.initial(request, *args, **kwargs)-->self.check_throttles(request)
class APIView(View):
    def check_throttles(self, request):
        throttle_durations = []
        # 遍历获取限流类的实例化对象
        for throttle in self.get_throttles():
            ###  重点 如何实例化这个对象的 ###
            # 调用实例方法
            if not throttle.allow_request(request, self):
                # 不通过把需要的时间等待的时间添加到列表中
                throttle_durations.append(throttle.wait())
		# 如果有值把所以等待时间放入durations这个列表中
        if throttle_durations:
            durations = [
                duration for duration in throttle_durations
                if duration is not None
            ]
		   # 取最大的等待时间 很好理解 如果有多个限流组件 得等待最长时间通过才行
            duration = max(durations, default=None)
            self.throttled(request, duration)
    def throttled(self, request, wait):
		# 抛出错误提示 Throttled类中其实可以定义等待时间提示
        raise exceptions.Throttled(wait)
# 究竟是怎么实例化的
class BaseThrottle:
    # 继承必须重写 就类似java的抽象类
    def allow_request(self, request, view):
        raise NotImplementedError('.allow_request() must be overridden')
	# 获取IP
    def get_ident(self, request):
        xff = request.META.get('HTTP_X_FORWARDED_FOR')
        remote_addr = request.META.get('REMOTE_ADDR')
        num_proxies = api_settings.NUM_PROXIES

        if num_proxies is not None:
            if num_proxies == 0 or xff is None:
                return remote_addr
            addrs = xff.split(',')
            client_addr = addrs[-min(num_proxies, len(addrs))]
            return client_addr.strip()

        return ''.join(xff.split()) if xff else remote_addr
	# 等待时间
    def wait(self):
        return None
# SimpleRateThrottle继承BaseThrottle扩展了一些
class SimpleRateThrottle(BaseThrottle):
    cache = default_cache
    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        # 第一次进来走这个
        if not getattr(self, 'rate', None):
            # 1 get_rate方法
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

    def get_cache_key(self, request, view):
        raise NotImplementedError('.get_cache_key() must be overridden')
	# 1 get_rate方法
    def get_rate(self):
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            # 这里相当于拿到 THROTTLE_RATES={"user":"5/m"} scope="user"
            return self.THROTTLE_RATES[self.scope] # 5/m
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def parse_rate(self, rate):
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        # 次数
        num_requests = int(num)
        # 就是个取值 有意思的是[period[0]] 代表 "5/hour" 可以取"h"
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        # 返回次数 和 时间
        return (num_requests, duration)
	# 2 前面源码知道会实例化之后会调用allow_request
    def allow_request(self, request, view):
        if self.rate is None:
            return True
		# 通过缓存获取请求中带上的用户唯一标识
        self.key = self.get_cache_key(request, view)
        if self.key is None:
            # 获取不到唯一标识,默认放行
            return True
		# 获取对应标识用户的访问历史		
        self.history = self.cache.get(self.key, [])
        self.now = self.timer()
		# 将已经超过配置周期期限的历史去掉
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        # 判断是否大于限制数量    
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()

    def throttle_success(self):
        # 放行,同时这里会把当次请求插入历史并更新缓存
        self.history.insert(0, self.now)
        self.cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
        return False

    def wait(self):
        if self.history:
            # 需要等待的时间 - ( 当前时间 - 最早访问记录 )  => 还需要等待多久
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            return None
		# 返回等待时间 这里我有一个疑惑为啥要除float(available_requests) 不直接返回 remaining_duration 暂时我也不太清楚
        return remaining_duration / float(available_requests)
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值