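首先安装 djangorestframework-simplejwt:`pip install djangorestframework-simplejwt`(下文代码使用的是 `rest_framework_simplejwt` 包和 `SIMPLE_JWT` 配置项,而不是 djangorestframework-jwt)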
在 settings.py 中配置自己的 SSO 应用 bmiss_sso(该应用需要已加入 INSTALLED_APPS):
# settings.py (excerpt)
from datetime import timedelta

# Replace Django's built-in user model with the SSO app's own User model.
AUTH_USER_MODEL = "bmiss_sso.User"

REST_FRAMEWORK = {
    # Authenticate requests with the project's SSO JWT authentication class.
    "DEFAULT_AUTHENTICATION_CLASSES": (
        "bmiss_sso.authentication.SSOTokenAuthentication",
    ),
    # Require an authenticated user unless a view overrides permission_classes.
    "DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.IsAuthenticated",),
    # ... other project-specific REST_FRAMEWORK settings go here ...
}

SIMPLE_JWT = {
    # Prefix expected in the Authorization header, e.g. "Token <access token>".
    "AUTH_HEADER_TYPES": ("Token",),
    # Access tokens expire 30 minutes after being issued.
    "ACCESS_TOKEN_LIFETIME": timedelta(minutes=30),
}
AUTH_USER_MODEL:用自己的 User 模型替换 Django 自带的 User 模型
REST_FRAMEWORK:在此注册自己的认证类
DEFAULT_AUTHENTICATION_CLASSES:默认认证类
DEFAULT_PERMISSION_CLASSES:默认权限类
SIMPLE_JWT:设置 access token 的过期时间,以及请求头中 token 的前缀类型
在上述项目中,SSO 认证由 bmiss_sso 应用中 authentication.py 文件里的 SSOTokenAuthentication 类实现:
authentication.py
import threading

from rest_framework_simplejwt.authentication import JWTAuthentication

# Per-thread storage for the id of the user authenticated on this thread.
_auth_ctx = threading.local()


class SSOTokenAuthentication(JWTAuthentication):
    """JWT authentication that also records the authenticated user's id.

    After ``JWTAuthentication.authenticate`` succeeds, ``user.id`` is stored
    on the thread-local ``_auth_ctx.user_id``. On any other outcome it is
    left as ``None``.
    """

    def authenticate(self, request):
        """Authenticate ``request`` and record the user id thread-locally.

        :returns: the ``(user, token)`` pair from the parent class, or ``None``
            when the request carries no JWT.
        :raises: whatever the parent class raises for an invalid token.
        """
        # Worker threads are reused across requests: reset first so a request
        # that is not (or fails to be) authenticated never sees the previous
        # request's user id.
        _auth_ctx.user_id = None
        result = super().authenticate(request)
        if result is not None:
            user = result[0]
            _auth_ctx.user_id = user.id
        return result
这里继承 simplejwt 的 JWTAuthentication 类(即 JWT 认证类),
并重写 authenticate 方法:认证成功后,将 user.id 保存到线程局部变量 _auth_ctx 的 user_id 属性中
urls.py
from django.urls import include, path
from rest_framework.routers import DefaultRouter
from bmiss_sso.endpoints.token import (
KnowledgeBaseIDTokenAPIView,
TextileEcoSysIDTokenAPIView,
TokenRefreshView,
)
from bmiss_sso.endpoints.user import UserViewSet
# Router that registers UserViewSet under the "users" prefix.
router = DefaultRouter()
router.register("users", UserViewSet)

# Application namespace for reverse(), e.g. "bmiss_sso:sso-token-refresh".
app_name = "bmiss_sso"

urlpatterns = [
    # Routes generated by the router for UserViewSet.
    path("", include(router.urls)),
    # Exchange a refresh token for a new access token.
    path(
        "token/refresh/",
        TokenRefreshView.as_view(),
        name="sso-token-refresh",
    ),
    # Per-application id_token -> project JWT exchange endpoints.
    path(
        "knowledge_base/token/",
        KnowledgeBaseIDTokenAPIView.as_view(),
        name="sso-knowledge-base-token",
    ),
    path(
        "textile_ecosystem/token/",
        TextileEcoSysIDTokenAPIView.as_view(),
        name="sso-textile-ecosystem-token",
    ),
]
token/refresh/ 用于刷新 access token,实现如下:
class TokenRefreshView(jwt_views.TokenRefreshView):
    # Thin subclass of simplejwt's refresh view; it exists only to attach
    # Swagger metadata (summary + 200 response serializer) to POST.
    # Docstrings are intentionally omitted, as the API schema may read them.

    @swagger_auto_schema(
        operation_summary="刷新access token",
        responses={200: TokenRefreshResSerializer()},
    )
    def post(self, request, *args, **kwargs):
        # Delegate unchanged to simplejwt's implementation.
        return super().post(request, *args, **kwargs)
knowledge_base/token/ 用于校验 knowledge_base 应用签发的 id_token 并换取 token,实现如下:
class KnowledgeBaseIDTokenAPIView(IDTokenBaseAPIView):
    """
    KnowledgeBase Token获取
    """

    # Exchanges an id_token issued for the KnowledgeBase application for
    # project-wide JWTs. NOTE: the docstring above is kept verbatim because
    # DRF / drf-yasg may surface it as this endpoint's description.

    # Public key used to verify KnowledgeBase-issued id_tokens.
    public_key = KB_PUBLIC_KEY

    def get_user_type(self) -> int:
        # Users created through this endpoint get the COMPANY type.
        return UserType.COMPANY

    @swagger_auto_schema(
        operation_summary="KnowledgeBase Token获取",
        query_serializer=IDTokenSerializer(),
        responses={200: TokenReqSerializer()},
    )
    def get(self, request):
        # Override exists only to attach the Swagger metadata above.
        return super().get(request)
textile_ecosystem/token/ 用于校验 textile_ecosystem(纺织生态圈)应用签发的 id_token 并换取 token,实现如下:
class TextileEcoSysIDTokenAPIView(IDTokenBaseAPIView):
    """
    纺织生态圈Token获取
    """

    # Exchanges an id_token issued for the textile-ecosystem application for
    # project-wide JWTs. NOTE: the docstring above is kept verbatim because
    # DRF / drf-yasg may surface it as this endpoint's description.

    # Public key used to verify textile-ecosystem-issued id_tokens.
    public_key = TE_PUBLIC_KEY

    def get_user_type(self) -> int:
        # Users created through this endpoint get the COMPANY type.
        return UserType.COMPANY

    @swagger_auto_schema(
        operation_summary="纺织生态圈Token获取",
        query_serializer=IDTokenSerializer(),
        responses={200: TokenReqSerializer()},
    )
    def get(self, request):
        # Override exists only to attach the Swagger metadata above.
        return super().get(request)
这两个类都继承自 IDTokenBaseAPIView,即负责 id_token 认证的基类,实现如下:
class IDTokenBaseAPIView(BaseAPIView):
    """
    Base view that exchanges an application-issued ``id_token`` for JWTs.

    Each application registered with the Zhejiang government service center
    has its own PUBLIC_KEY for the ``id_token`` it issues, so the token must
    be decoded with the matching key (set ``public_key`` in subclasses). The
    access token obtained, however, is usable across the whole project.
    """

    # Anonymous endpoint: the id_token query parameter is the credential.
    permission_classes = ()
    authentication_classes = ()
    # Public key used to verify id_tokens; subclasses must override it.
    public_key = None

    def get_user_type(self) -> int:
        """Return the ``User.type`` assigned to users created by this view."""
        raise NotImplementedError

    def get_or_create_user(self, user_info: dict) -> User:
        """Return the user matching ``user_info["username"]``, creating it if absent.

        The remaining fields are copied from ``user_info`` only when a new
        user is created. An existing user is returned unchanged.
        """
        user_type = self.get_user_type()
        username = user_info["username"]
        # NOTE(review): get-then-create is not atomic; two concurrent first
        # logins for the same username can both reach create(). Consider
        # handling IntegrityError if username is unique.
        try:
            return User.objects.get(username=username)
        except User.DoesNotExist:
            pass

        extend_fields = user_info.get("extendFields")
        if not isinstance(extend_fields, dict):
            # extendFields may be absent or not a mapping: no company data.
            extend_fields = {}
        return User.objects.create(
            username=username,
            type=user_type,
            name=user_info["name"],
            email=user_info.get("email") or "",
            mobile=user_info.get("mobile"),
            external_id=user_info.get("externalId"),
            ou_id=user_info.get("ouId"),
            ou_name=user_info.get("ouName"),
            zzjw_user_type=user_info.get("userType"),
            company_name=extend_fields.get("companyName"),
            company_address=extend_fields.get("companyAddress"),
        )

    def get(self, request):
        """Verify the ``id_token`` query parameter and return a JWT pair.

        The user is looked up (or created) from the token's claims. A fresh
        refresh/access token pair is issued, together with basic user info.
        """
        q_ser = IDTokenSerializer(data=request.query_params)
        # raise_exception must be passed by keyword: recent DRF versions
        # declare it keyword-only, so is_valid(True) raises TypeError.
        q_ser.is_valid(raise_exception=True)
        id_token = q_ser.validated_data["id_token"]

        user_info = get_user_info(id_token, self.public_key)
        user = self.get_or_create_user(user_info)

        refresh = RefreshToken.for_user(user)
        data = {
            "refresh": str(refresh),
            "access": str(refresh.access_token),
            "name": user.name,
            "company_name": user.company_name,
            "user_type": user.type,
            "user_id": user.id,
        }
        serializer = TokenReqSerializer(data=data)
        serializer.is_valid(raise_exception=True)

        if api_settings.UPDATE_LAST_LOGIN:
            update_last_login(None, user)
        return Response(serializer.validated_data)
基类负责校验 id_token:若用户表中没有此用户则先创建用户,然后为该用户签发 refresh/access token,并返回 data 信息。
get_user_info函数
def get_user_info(id_token, public_key):
    """Verify an RS256-signed ``id_token`` and return its decoded claims.

    :param id_token: the encoded JWT (str or bytes).
    :param public_key: the issuing application's RSA public key.
    :returns: the token payload as a dict.
    :raises exceptions.AuthenticationFailed: if the key cannot be prepared or
        the token fails to decode or verify for any reason.
    """
    try:
        algo = RSAAlgorithm(RSAAlgorithm.SHA256)
        prepared_key = algo.prepare_key(public_key)
        # Signature verification is PyJWT's default. The legacy ``verify=``
        # keyword was removed in PyJWT 2.0, so it is not passed here.
        token_info = jwt.decode(
            force_bytes(id_token), prepared_key, algorithms=["RS256"]
        )
    except Exception as exc:
        # Deliberately broad: any failure to verify the token is reported to
        # the client as an authentication failure. The cause is chained.
        raise exceptions.AuthenticationFailed("令牌认证失败") from exc
    # jwt.decode already returns a JSON-parsed dict; return it directly.
    return token_info
此函数使用对应应用的 public_key 校验 id_token 的签名并解析 JWT,返回解析出的用户信息;校验失败时抛出 AuthenticationFailed("令牌认证失败")。
get_or_create_user函数
def get_or_create_user(self, user_info: dict) -> User:
    """Return the user matching ``user_info["username"]``, creating it if absent.

    The remaining fields are copied from ``user_info`` only when a new user
    is created. An existing user is returned unchanged.
    """
    user_type = self.get_user_type()
    username = user_info["username"]
    # NOTE(review): get-then-create is not atomic; two concurrent first
    # logins for the same username can both reach create(). Consider
    # handling IntegrityError if username is unique.
    try:
        return User.objects.get(username=username)
    except User.DoesNotExist:
        pass

    extend_fields = user_info.get("extendFields")
    if not isinstance(extend_fields, dict):
        # extendFields may be absent or not a mapping: no company data.
        extend_fields = {}
    return User.objects.create(
        username=username,
        type=user_type,
        name=user_info["name"],
        email=user_info.get("email") or "",
        mobile=user_info.get("mobile"),
        external_id=user_info.get("externalId"),
        ou_id=user_info.get("ouId"),
        ou_name=user_info.get("ouName"),
        zzjw_user_type=user_info.get("userType"),
        company_name=extend_fields.get("companyName"),
        company_address=extend_fields.get("companyAddress"),
    )
此函数根据解析出的用户信息查询用户表:若已存在该用户,直接返回查询到的用户对象;否则用这些信息创建新用户后返回。
最后通过 refresh = RefreshToken.for_user(user) 为该用户生成 refresh token(若启用了 simplejwt 的 token_blacklist 应用,该令牌还会被记录到“未完成令牌”(OutstandingToken)列表中),最终返回:
# Response payload: the new token pair plus basic fields of the user.
data = dict(
    refresh=str(refresh),
    access=str(refresh.access_token),
    name=user.name,
    company_name=user.company_name,
    user_type=user.type,
    user_id=user.id,
)