参考 FastApi 、SpringMVC中参数的处理,编写了一个简单的使用 pydantic 对 Flask 中的参数进行校验并且包装为 pydantic 中的 BaseModel类 的工具,将参数映射到 BaseModel 中,简化参数校验以及对参数的包装。
# -*- coding: utf-8 -*-
# @Time : 2024/5/18 15:58
# @File : args_injection.py
"""
注入依赖的装饰器
"""
import inspect
from functools import wraps
from http import HTTPStatus
from typing import Type
from flask import Flask
from flask import request
from pydantic import BaseModel
from pydantic import ValidationError
class InjectionException(Exception):
"""
注入阶段自定义的异常,初始化后该异常会被捕获且直接返回 400 状态码
"""
def __init__(self, msg: str = HTTPStatus.BAD_REQUEST.phrase, code: int = HTTPStatus.BAD_REQUEST.value):
self.msg = msg
self.code = code
class ExceptionWrapper(Exception):
"""
捕获依赖注入阶段所有的异常,并且包装为 ExceptionWrapper
"""
def __init__(self, e):
self.exception = e
class RequestBody:
"""
将请求体解析到一个pydantic的BaseModel,目前只支持 application/x-www-form-urlencoded、application/json、
multipart/form-data 请求体类型的反序列化注入
类型传递方式有两种: \n
参数传递:user=RequestBody(User) \n
注解传递:user:User=RequestBody()
"""
def __init__(self, param_type: Type[BaseModel] = None):
self.param_type = param_type
class RequestParams:
"""
将查询全部参数映射到一个pydantic的BaseModel
类型传递方式有两种: \n
参数传递:user=RequestParams(User) \n
注解传递:user:User=RequestParams()
"""
def __init__(self, param_type: Type[BaseModel] = None):
self.param_type = param_type
def injection():
"""
标记该装饰器的函数(function而不是method)在请求到来时会将函数中默认值为 RequestParams、RequestBody 类型的参数注入到函数参数中。
注入的参数类型必须是 pydantic.BaseModel 类型的子类,注入时候会自动进行 pydantic.BaseModel 的属性校验。
"""
def decorator(f):
@wraps(f)
def wrapper(*args, **kwargs):
if not inspect.ismethod(f) and not inspect.isfunction(f):
raise TypeError(f"decorated object {f.__name__} must be a method or function.")
return _parse_request(f, *args, **kwargs)
return wrapper
return decorator
def _parse_request(_f, *args, **kwargs):
_func_params = inspect.signature(_f).parameters
for _name, _info in _func_params.items():
if _info.default is inspect.Parameter.empty:
continue
_default = _info.default
try:
if isinstance(_default, RequestParams): # 解析查询参数
kwargs[_name] = _parse_query(_f, _default, _name, _info)
elif isinstance(_default, RequestBody): # 解析请求体
kwargs[_name] = _parse_body(_f, _default, _name, _info)
except ValidationError as e:
raise ExceptionWrapper(e)
except InjectionException as e:
raise ExceptionWrapper(e)
return _f(*args, **kwargs)
def _parse_body(_f, _default: RequestBody, _name: str, _info):
"""
解析请求体
:param _f: 被装饰的函数
:param _default: 默认值
:param _name: 参数名
:param _info: 参数信息
:return:
"""
_type = _default.param_type
_type = _default.param_type if _default.param_type else (
_info.annotation if _info.annotation is not inspect.Parameter.empty else None)
if not _type or not issubclass(_type, BaseModel):
raise TypeError(
f"Function '{_f.__name__}' parameter '{_name}' must be a subclass of pydantic.BaseModel but got {_type}")
_content_type: str = request.headers.get("Content-Type", "")
if _content_type == 'application/json':
return _type.model_validate(request.json)
elif (_content_type.startswith("multipart/form-data")
or _content_type.startswith("application/x-www-form-urlencoded")):
return _type.model_validate(request.form.to_dict())
elif not _content_type:
raise InjectionException(f"missing content-type")
raise InjectionException(f"unsupported content-type: {_content_type}")
def _parse_query(_f, _default: RequestParams, _name: str, _info):
"""
解析查询参数
:param _f: 被装饰的函数
:param _default: 默认值
:param _name: 参数名
:param _info: 参数信息
:return:
"""
_type = _default.param_type
_type = _default.param_type if _default.param_type else (
_info.annotation if _info.annotation is not inspect.Parameter.empty else None)
if not _type or not issubclass(_type, BaseModel):
raise TypeError(
f"Function '{_f.__name__}' parameter '{_name}' must be a subclass of pydantic.BaseModel but got {_type}")
return _type.model_validate(request.args.to_dict())
def init_args_injection(app: Flask, display_detail: bool = False):
"""
初始化参数注入工具,在 app 运行之前调用,否则参数注入失败会抛出异常
"""
@app.errorhandler(ExceptionWrapper)
def handle_parameter_exception(e: ExceptionWrapper):
if isinstance(e.exception, ValidationError):
errors = e.exception.errors()
if display_detail:
return [
{
"type": e.get("type", None),
"field": e.get("loc", None)[0] if e.get("loc") else None,
"msg": e.get("msg", None),
}
for e in errors
], 400
else:
return {"msg": HTTPStatus.BAD_REQUEST.phrase}, 400
if isinstance(e.exception, InjectionException):
return {"msg": e.exception.msg}, e.exception.code
def main():
from pydantic import constr
app = Flask(__name__)
app.config['JSON_AS_ASCII'] = False
app.config['JSON_SORT_KEYS'] = False
init_args_injection(app)
class User(BaseModel):
name: constr(min_length=3)
age: constr(min_length=3)
@app.post('/test')
@args_injection()
def test(user: User = RequestBody(User)):
print(user)
return user.name
app.run(host='0.0.0.0', port=8080, debug=True)
if __name__ == '__main__':
main()
- 代码中 args_injection 装饰器装饰函数表示对 test 函数中的参数进行解析尝试注入。
- test 函数中参数 user: User = RequestBody(User) 后面的 RequestBody表示从请求体中获取参数
- 请求入参类继承BaseModel,对应的约束规则需要查看 pydantic 的文档。
- 参数校验不通过会执行 init_args_injection 中定义的异常处理,可以根据需要自行修改。
- _parse_body 函数主要映射表单数据和json类型的请求体数据。 _parse_body 函数中不解析 multipart/form-data 表单中的文件,文件读取需要在接口从request中自行获取,也可根据需要修改上述代码
- ……