I. Usage
Instantiation
from flask_wtf import CSRFProtect
csrf = CSRFProtect()
Initialization
from flask import Flask
app = Flask(__name__)
...
app.config['WTF_CSRF_SECRET_KEY'] = 'xxx'  # secret key used to sign the token
...
csrf.init_app(app)
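By default, csrf issues and validates tokens for the ['POST', 'PUT', 'PATCH', 'DELETE'] methods. To change the set of protected methods, set WTF_CSRF_METHODS in the config.
To exclude a particular API endpoint from validation, decorate it with csrf.exempt.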
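II. Implementation
1. Generating the token
The generate_csrf() function creates a random value and stores it in the session, then signs that value (via URLSafeTimedSerializer.dumps) to produce the token. The token is cached on g under the key csrf_token, or under the name configured in WTF_CSRF_FIELD_NAME, and is delivered to the client in the response header or body.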
def generate_csrf(secret_key=None, token_key=None):
    """Generate a CSRF token. The token is cached for a request, so multiple
    calls to this function will generate the same token.

    During testing, it might be useful to access the signed token in
    ``g.csrf_token`` and the raw token in ``session['csrf_token']``.

    :param secret_key: Used to securely sign the token. Default is
        ``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
    :param token_key: Key where token is stored in session for comparison.
        Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.
    """
    secret_key = _get_config(
        secret_key, 'WTF_CSRF_SECRET_KEY', current_app.secret_key,
        message='A secret key is required to use CSRF.'
    )
    field_name = _get_config(
        token_key, 'WTF_CSRF_FIELD_NAME', 'csrf_token',
        message='A field name is required to use CSRF.'
    )

    if field_name not in g:
        s = URLSafeTimedSerializer(secret_key, salt='wtf-csrf-token')

        if field_name not in session:
            session[field_name] = hashlib.sha1(os.urandom(64)).hexdigest()

        try:
            token = s.dumps(session[field_name])
        except TypeError:
            session[field_name] = hashlib.sha1(os.urandom(64)).hexdigest()
            token = s.dumps(session[field_name])

        setattr(g, field_name, token)

    return g.get(field_name)
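To see the relationship between the raw session value and the signed token outside of a Flask request, the same steps can be reproduced with itsdangerous directly. This is a sketch under stated assumptions: a plain dict (fake_session) stands in for Flask's session, the secret is a throwaway placeholder, the one-hour max_age is an illustrative choice, and is_tampered is a hypothetical helper.

```python
import hashlib
import os

from itsdangerous import BadSignature, URLSafeTimedSerializer

# Hypothetical stand-ins for Flask's session and the configured secret.
fake_session = {}
secret_key = 'not-a-real-secret'

serializer = URLSafeTimedSerializer(secret_key, salt='wtf-csrf-token')

# Raw value: a 40-character hex digest that stays on the server side.
fake_session['csrf_token'] = hashlib.sha1(os.urandom(64)).hexdigest()

# Signed token: what the client receives and later sends back.
signed_token = serializer.dumps(fake_session['csrf_token'])

# Loading the signed token recovers the raw value for comparison.
recovered = serializer.loads(signed_token, max_age=3600)


def is_tampered(token):
    """Return True when the token fails signature or age verification."""
    try:
        serializer.loads(token, max_age=3600)
    except BadSignature:
        return True
    return False
```

Because the signature depends on the secret key, a token signed with a different key does not load, which is what prevents an attacker from forging a token.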
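2. Validating the token
When the app receives a request, the CSRF extension looks for the token in the form data or the request headers, loads (unsigns) it, and compares the result with the value stored in the session. The protect() function drives this check.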
def protect(self):
    if request.method not in current_app.config['WTF_CSRF_METHODS']:
        return

    try:
        validate_csrf(self._get_csrf_token())
    except ValidationError as e:
        logger.info(e.args[0])
        self._error_response(e.args[0])

    if request.is_secure and current_app.config['WTF_CSRF_SSL_STRICT']:
        if not request.referrer:
            self._error_response('The referrer header is missing.')

        good_referrer = 'https://{0}/'.format(request.host)

        if not same_origin(request.referrer, good_referrer):
            self._error_response('The referrer does not match the host.')

    g.csrf_valid = True  # mark this request as CSRF valid
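The strict-SSL branch of protect() relies on a same_origin helper that the excerpt does not show. A minimal sketch of such a check follows, assuming it compares scheme, host, and port. The referrer_error function is a hypothetical wrapper that returns the error message instead of calling _error_response, so the branch can be exercised without a request context.

```python
from urllib.parse import urlparse


def same_origin(current_uri, compare_uri):
    """Return True when both URLs share scheme, host, and port."""
    current = urlparse(current_uri)
    compare = urlparse(compare_uri)
    return (
        current.scheme == compare.scheme
        and current.hostname == compare.hostname
        and current.port == compare.port
    )


def referrer_error(referrer, host):
    """Mirror the strict-SSL branch: return an error message, or None if OK."""
    if not referrer:
        return 'The referrer header is missing.'
    good_referrer = 'https://{0}/'.format(host)
    if not same_origin(referrer, good_referrer):
        return 'The referrer does not match the host.'
    return None
```

A referrer from another host, or the same host over plain http, fails the comparison, so a valid token alone is not enough on HTTPS when WTF_CSRF_SSL_STRICT is enabled.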
3. How csrf.exempt works
Registration: when the app starts, each route (endpoint) decorated with exempt is recorded in a set.
def exempt(self, view):
    """Mark a view or blueprint to be excluded from CSRF protection.

    ::

        @app.route('/some-view', methods=['POST'])
        @csrf.exempt
        def some_view():
            ...

    ::

        bp = Blueprint(...)
        csrf.exempt(bp)
    """
    if isinstance(view, Blueprint):
        self._exempt_blueprints.add(view.name)
        return view

    if isinstance(view, string_types):
        view_location = view
    else:
        if isinstance(view, views.MethodViewType):
            view_location = '.'.join((view.__module__, view.__name__.lower()))
        else:
            view_location = '.'.join((view.__module__, view.__name__))

    self._exempt_views.add(view_location)
    return view
When a route is requested, the extension first checks whether the view function was decorated with exempt; if so, it skips the protect() check.
# Registered inside CSRFProtect.init_app, so ``self`` is the CSRFProtect instance.
@app.before_request
def csrf_protect():
    if not app.config['WTF_CSRF_ENABLED']:
        return

    if not app.config['WTF_CSRF_CHECK_DEFAULT']:
        return

    if request.method not in app.config['WTF_CSRF_METHODS']:
        return

    if not request.endpoint:
        return

    if request.blueprint in self._exempt_blueprints:
        return

    view = app.view_functions.get(request.endpoint)
    dest = '{0}.{1}'.format(view.__module__, view.__name__)

    if dest in self._exempt_views:
        return

    self.protect()
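The registration and lookup steps both reduce to building the same "module.name" string and checking set membership. The sketch below isolates that idea without Flask. ExemptRegistry, is_exempt, and the two view functions are hypothetical names used only for illustration; the real extension also handles blueprints, strings, and class-based views.

```python
class ExemptRegistry:
    """Toy stand-in for the exempt bookkeeping, keyed by dotted view name."""

    def __init__(self):
        self._exempt_views = set()

    @staticmethod
    def _location(view):
        # Same key format on both sides: "<module>.<function name>".
        return '{0}.{1}'.format(view.__module__, view.__name__)

    def exempt(self, view):
        # Decoration time: remember the view and return it unchanged.
        self._exempt_views.add(self._location(view))
        return view

    def is_exempt(self, view):
        # Request time: rebuild the key and test membership.
        return self._location(view) in self._exempt_views


registry = ExemptRegistry()


@registry.exempt
def webhook():
    return 'ok'


def update_profile():
    return 'saved'
```

Because the decorator returns the view unchanged, the exemption lives only in the registry. Both sides must build the same key for the lookup to match.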