三大组件的使用和源码分析
认证组件
认证使用
from rest_framework.authentication import BaseAuthentication
# 编写认证类 继承BaseAuthentication
class MyAuthentication(BaseAuthentication):
def authenticate(self, request):
"""
做用户的认证
1. 读取请求的token
2. 校验合法性
3. 返回值
1. 返回元组 认证成功 返回2个值(request.user request.auth)
2. 抛出异常 认证失败 -> 抛出错误信息
3. 返回None 多个类认证 [类1,类2.....] 如果都返回为None 那就是匿名用户
"""
# 请求头获取token
token = request.query_params.get("token")
# 请求头获取token
# token = request.META.get("HTTP_AUTHORIZATION")
if token:
user_token=UserInfo.objects.filter(token=token).first()
# 认证通过
if user_token:
return user_token.user,token
else:
raise AuthenticationFailed('认证失败')
else:
raise AuthenticationFailed('请求地址中需要携带token')
# 解决 response是403的问题
def authenticate_header(self, request):
return "API"
# 使用配置
# 1. 全局配置 在setting.py中配置
REST_FRAMEWORK = {
"DEFAULT_AUTHENTICATION_CLASSES": [
"ext.auth.MyAuthentication",
# "ext.auth.QueryParamsAuthentication",
# "ext.auth.HeaderAuthentication",
# "ext.auth.NoAuthentication",
]
}
# 2. 局部使用,在视图类上写
class API_view(APIView):
# 视图类中添加认证的类
authentication_classes = [MyAuthentication, ]
......
# 3. 局部禁用
class API_view(APIView):
# 视图类中添加认证的类
authentication_classes = []
......
认证源码精读
承接上次 那我们之间看执行流程
APIView --> dispatch(self, request, *args, **kwargs) --> self.initial(request, *args, **kwargs)
-->(认证走这个方法) self.perform_authentication(request) --> (里面就一个 request.user 所以回Request对象中找user)request.user
class Request:
@property
def user(self):
# 第1步. 判断是否有_user
if not hasattr(self, '_user'):
with wrap_attributeerrors():
# 第2步. 一开始我们没有_user所以我们看 self._authenticate()方法
self._authenticate()
return self._user
# 第2步.
def _authenticate(self):
# 2.1 从self.authenticators这里列表中获取对象
for authenticator in self.authenticators:
try:
# 执行对象的authenticator中的方法(这里就相当于执行我们创建的MyAuthentication中的authenticate方法) sellf就是Request对象
user_auth_tuple = authenticator.authenticate(self)
except exceptions.APIException:
self._not_authenticated()
raise
# 2.2 判断user_auth_tuple是否为空
if user_auth_tuple is not None:
self._authenticator = authenticator
# user_auth_tuple返回2个值 然后赋值
self.user, self.auth = user_auth_tuple
return
# 最后我们看看这个 self._not_authenticated()方法干了什么
self._not_authenticated()
def _not_authenticated(self):
self._authenticator = None
# 读取配置中的认证对象 有就实例化 没有就赋值为None
if api_settings.UNAUTHENTICATED_USER:
self.user = api_settings.UNAUTHENTICATED_USER()
else:
self.user = None
if api_settings.UNAUTHENTICATED_TOKEN:
self.auth = api_settings.UNAUTHENTICATED_TOKEN()
else:
self.auth = None
总结一下认证组件干了什么:
- 主要就是判断用户类型 通过认证 返回2个值 然后好进行下一步权限校验 失败就抛出异常
- 认证组件时或的关系 通过一个就能通过
权限组件
权限使用
from rest_framework.permissions import BasePermission
# 写一个类,继承BasePermission,重写has_permission,如果权限通过,就返回True,不通过就返回False
class UserPermission(BasePermission):
message = {"status": False, 'msg': "无权访问1"}
def has_permission(self, request, view):
# 如果该字段用了choice,通过get_字段名_display()就能取出choice后面的中文
if request.user.role == 3:
return True
return False
# 使用配置
# 1. 全局配置 在setting.py中配置
REST_FRAMEWORK = {
"DEFAULT_AUTHENTICATION_CLASSES": [
# "ext.auth.BossPermission",
# "ext.auth.ManagerPermission",
"ext.per.UserPermission",
]
}
# 2. 局部使用,在视图类上写
class API_view(APIView):
# 视图类中添加认证的类
permission_classes = [UserPermission,]
......
# 3. 局部禁用
class API_view(APIView):
# 视图类中添加认证的类
permission_classes = []
......
权限源码精读
APIView --> dispatch(self, request, *args, **kwargs) --> self.initial(request, *args, **kwargs)-->self.check_permissions(request)
class APIView(View):
def check_permissions(self, request):
# 遍历 get_permissions里的权限对象
for permission in self.get_permissions():
# has_permission 返回 True和False
if not permission.has_permission(request, self):
# 没有通过 就用 permission_denied 抛出异常
self.permission_denied(
request,
message=getattr(permission, 'message', None),
code=getattr(permission, 'code', None)
)
权限组件的扩展
- 权限组件是且关系,即A条件 且 B条件 且 C条件,同时满足。
- 那么我们优化一下实现或关系
在自己的视图类中重写 check_permissions()方法
def check_permissions(self, request):
for permission in self.get_permissions():
if permission.has_permission(request, self):
return
else:
self.permission_denied(
request,
message=getattr(permission, 'message', None),
code=getattr(permission, 'code', None)
)
# 如果嫌每个类中都要重新方法 那就直接创建一个类继承
限流组件
限流组件的使用
from rest_framework.throttling import SimpleRateThrottle
from rest_framework import exceptions
from rest_framework import status
class ThrottledException(exceptions.APIException):
status_code = status.HTTP_429_TOO_MANY_REQUESTS
default_code = 'throttled'
# 直接继承SimpleRateThrottle 重写get_cache_key方法就行
class MyThrottle(SimpleRateThrottle):
# 就是可以理解成一个名称 避免重名概率和方便找到对应的限流类
scope = "user"
# 其他:'s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day' 不建议这种写法 写道全局配置会好一些
# THROTTLE_RATES = {"user": "10/m"}
def get_cache_key(self, request, view):
if request.user:
ident = request.user.pk # 用户ID
else:
# 获取IP 如果是匿名代理其实获取不到真正的IP
ident = self.get_ident(request)
return self.cache_format % {'scope': self.scope, 'ident': ident}
# 重写throttle_failure 自定义错误提示
def throttle_failure(self):
wait = self.wait()
detail = {
"code": 1005,
"data": "访问频率限制",
'detail': "需等待{}s才能访问".format(int(wait))
}
raise ThrottledException(detail)
# 使用配置
# 1. 全局配置 在setting.py中配置
REST_FRAMEWORK = {
# 限流列表
"DEFAULT_AUTHENTICATION_CLASSES": [
"ext.throttle.MyThrottle",
# 'ext.throttle.AnonRateThrottle', # 匿名用户限流
# 'ext.throttle.UserRateThrottle', # 登录用户限流
# 'ext.throttle.ScopedRateThrottle', # 针对某一个接口限流(只能在APIView类使用)
],
# 限制访问频次
'DEFAULT_THROTTLE_RATES': {
'xxx': '1/m',
'user': '10/m'
}
}
# 2. 局部使用,在视图类上写
class API_view(APIView):
# 视图类中添加认证的类
throttle_classes = [MyThrottle, ]
......
# 3. 局部禁用
class API_view(APIView):
# 视图类中添加认证的类
throttle_classes = []
......
限流组件源码精读
APIView --> dispatch(self, request, *args, **kwargs) --> self.initial(request, *args, **kwargs)-->self.check_throttles(request)
class APIView(View):
def check_throttles(self, request):
throttle_durations = []
# 遍历获取限流类的实例化对象
for throttle in self.get_throttles():
### 重点 如何实例化这个对象的 ###
# 调用实例方法
if not throttle.allow_request(request, self):
# 不通过把需要的时间等待的时间添加到列表中
throttle_durations.append(throttle.wait())
# 如果有值把所以等待时间放入durations这个列表中
if throttle_durations:
durations = [
duration for duration in throttle_durations
if duration is not None
]
# 取最大的等待时间 很好理解 如果有多个限流组件 得等待最长时间通过才行
duration = max(durations, default=None)
self.throttled(request, duration)
def throttled(self, request, wait):
# 抛出错误提示 Throttled类中其实可以定义等待时间提示
raise exceptions.Throttled(wait)
# 究竟是怎么实例化的
class BaseThrottle:
# 继承必须重写 就类似java的抽象类
def allow_request(self, request, view):
raise NotImplementedError('.allow_request() must be overridden')
# 获取IP
def get_ident(self, request):
xff = request.META.get('HTTP_X_FORWARDED_FOR')
remote_addr = request.META.get('REMOTE_ADDR')
num_proxies = api_settings.NUM_PROXIES
if num_proxies is not None:
if num_proxies == 0 or xff is None:
return remote_addr
addrs = xff.split(',')
client_addr = addrs[-min(num_proxies, len(addrs))]
return client_addr.strip()
return ''.join(xff.split()) if xff else remote_addr
# 等待时间
def wait(self):
return None
# SimpleRateThrottle继承BaseThrottle扩展了一些
class SimpleRateThrottle(BaseThrottle):
cache = default_cache
timer = time.time
cache_format = 'throttle_%(scope)s_%(ident)s'
scope = None
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
def __init__(self):
# 第一次进来走这个
if not getattr(self, 'rate', None):
# 1 get_rate方法
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)
def get_cache_key(self, request, view):
raise NotImplementedError('.get_cache_key() must be overridden')
# 1 get_rate方法
def get_rate(self):
if not getattr(self, 'scope', None):
msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
self.__class__.__name__)
raise ImproperlyConfigured(msg)
try:
# 这里相当于拿到 THROTTLE_RATES={"user":"5/m"} scope="user"
return self.THROTTLE_RATES[self.scope] # 5/m
except KeyError:
msg = "No default throttle rate set for '%s' scope" % self.scope
raise ImproperlyConfigured(msg)
def parse_rate(self, rate):
if rate is None:
return (None, None)
num, period = rate.split('/')
# 次数
num_requests = int(num)
# 就是个取值 有意思的是[period[0]] 代表 "5/hour" 可以取"h"
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
# 返回次数 和 时间
return (num_requests, duration)
# 2 前面源码知道会实例化之后会调用allow_request
def allow_request(self, request, view):
if self.rate is None:
return True
# 通过缓存获取请求中带上的用户唯一标识
self.key = self.get_cache_key(request, view)
if self.key is None:
# 获取不到唯一标识,默认放行
return True
# 获取对应标识用户的访问历史
self.history = self.cache.get(self.key, [])
self.now = self.timer()
# 将已经超过配置周期期限的历史去掉
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
# 判断是否大于限制数量
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
def throttle_success(self):
# 放行,同时这里会把当次请求插入历史并更新缓存
self.history.insert(0, self.now)
self.cache.set(self.key, self.history, self.duration)
return True
def throttle_failure(self):
return False
def wait(self):
if self.history:
# 需要等待的时间 - ( 当前时间 - 最早访问记录 ) => 还需要等待多久
remaining_duration = self.duration - (self.now - self.history[-1])
else:
remaining_duration = self.duration
available_requests = self.num_requests - len(self.history) + 1
if available_requests <= 0:
return None
# 返回等待时间 这里我有一个疑惑为啥要除float(available_requests) 不直接返回 remaining_duration 暂时我也不太清楚
return remaining_duration / float(available_requests)