装饰器在文件（模块）加载时就会被调用：模块级的 `@decorator` 在函数/类定义语句执行时立即应用，而它返回的包装函数要到每次调用被装饰对象时才执行。
- 运用在 MySQL 中：用装饰器统一获取连接、开启事务、提交/回滚，并释放游标和连接
import functools
import pymysql
from pymysql.cursors import DictCursor
from DBUtils.PooledDB import PooledDB
class Mysql(object):
    """Holder for one pooled MySQL connection plus a cursor opened on it.

    A single ``PooledDB`` is created lazily on first use and stored on the
    class, so every ``Mysql`` instance draws its connection from the same pool.

    NOTE(review): ``conf`` is not imported in this excerpt; presumably a project
    settings module. The pool is created with check-then-set and no lock, so
    concurrent first calls could each build a pool. ``charset="utf8"`` is MySQL's
    3-byte utf8; consider ``utf8mb4`` if 4-byte characters must be stored.
    """

    # Shared connection pool (name-mangled to ``_Mysql__pool``); built on first use.
    __pool = None

    def __init__(self):
        conn = Mysql.__get_conn__()
        self._conn = conn
        self._cursor = conn.cursor()

    @staticmethod
    def __get_conn__():
        """Return a connection from the shared pool, creating the pool if needed."""
        pool = Mysql.__pool
        if pool is None:
            pool = PooledDB(
                creator=pymysql,
                maxconnections=100,
                mincached=1,
                blocking=True,
                host=conf.DB_HOST,
                port=conf.DB_PORT,
                user=conf.DB_USER,
                # NOTE(review): other settings use DB_* names; confirm ``conf.pwd`` is intended.
                password=conf.pwd,
                database=conf.DB_NAME,
                charset="utf8",
                cursorclass=DictCursor,
            )
            Mysql.__pool = pool
        return pool.connection()

    connection = property(lambda self: self._conn,
                          doc="The connection this instance checked out of the pool.")

    cursor = property(lambda self: self._cursor,
                      doc="The cursor opened on ``connection`` in ``__init__``.")
def db_wrap(method):
    """Decorator: run ``method`` inside a transaction on a pooled MySQL connection.

    The decorated function is called as ``method(cursor, *args, **kwargs)``;
    callers omit the cursor argument, it is injected here.

    - On success the transaction is committed and ``method``'s return value is returned.
    - On any exception (including ``BaseException`` such as ``KeyboardInterrupt``)
      the transaction is rolled back and the exception is re-raised.
    - The cursor and the connection are always closed. For a pooled connection,
      closing it hands it back to the pool.
    """
    @functools.wraps(method)
    def _wrapper(*args, **kwargs):
        db = Mysql()
        conn = db.connection
        # Reuse the cursor Mysql() already opened. Opening a second one here
        # (as before) leaked the first cursor on every call.
        cursor = db.cursor
        try:
            conn.begin()
            retval = method(cursor, *args, **kwargs)
            conn.commit()
        except BaseException:
            conn.rollback()
            raise
        finally:
            # Nested so a failing cursor.close() cannot keep the connection out of the pool.
            try:
                cursor.close()
            finally:
                conn.close()
        return retval
    return _wrapper
# e.g. insert(table_name, {"id": str, "pre": str, "user": str, ...}) -- cursor is injected by db_wrap
@db_wrap
def insert(cursor, table_name, values):
    """Insert one row into ``table_name``.

    Wrapped by ``db_wrap``, so callers invoke ``insert(table_name, values)``.
    The cursor is injected and the statement runs inside a transaction.

    :param cursor: DB-API cursor supplied by ``db_wrap``.
    :param table_name: target table; formatted into the SQL text verbatim.
    :param values: mapping of column name -> value for the new row.
    :return: whatever ``cursor.execute`` returns (for pymysql, the affected-row
        count), or 0 when ``values`` is empty (nothing is executed).
    """
    if not values:
        return 0
    columns = list(values)
    placeholders = ", ".join(["%s"] * len(columns))
    # NOTE(review): table and column names are formatted into the SQL string;
    # only the values are bound as parameters. Never pass untrusted identifiers.
    sql = "insert into %s(%s) values(%s)" % (table_name, ','.join(columns), placeholders)
    return cursor.execute(sql, tuple(values[column] for column in columns))
# e.g. insert_list(table_name, [{...}, {...}, ...]) -- cursor is injected by db_wrap
@db_wrap
def insert_list(cursor, table_name, values, is_license_info=False):
    """Insert many rows with a single ``cursor.executemany`` call.

    Wrapped by ``db_wrap``, so callers invoke ``insert_list(table_name, rows, ...)``.

    :param cursor: DB-API cursor supplied by ``db_wrap``.
    :param table_name: target table; formatted into the SQL text verbatim.
    :param values: sequence of row dicts. Every row must have the same set of
        columns; the column order is taken from the first row.
    :param is_license_info: when true, use MySQL ``INSERT IGNORE`` so rows that
        would duplicate a unique key are skipped instead of failing the batch.
    :return: whatever ``cursor.executemany`` returns, or 0 when ``values`` is
        empty or contains an empty row (nothing is executed in that case).
    :raises ValueError: if a row's columns differ from the first row's.
    """
    rows = list(values)
    if not rows or not all(rows):
        return 0
    columns = list(rows[0])
    expected = set(columns)
    data = []
    for index, row in enumerate(rows):
        if set(row) != expected:
            raise ValueError("insert_list: row %d has columns %s, expected %s"
                             % (index, list(row), columns))
        # Look values up by column name. The old code took each row's values in
        # that row's own dict order, silently shifting data between columns.
        data.append(tuple(row[column] for column in columns))
    placeholders = ", ".join(["%s"] * len(columns))
    verb = "INSERT IGNORE INTO" if is_license_info else "INSERT INTO"
    # NOTE(review): identifiers are formatted into the SQL; never pass untrusted names.
    sql = "%s %s(%s) VALUES(%s)" % (verb, table_name, ','.join(columns), placeholders)
    return cursor.executemany(sql, data)
- 运用在 Django 接口版本控制中
def versioning(origin_object):
    """
    API version decorator. Different service versions keep the same URL; the
    client sends an 'X-Api-Version' header and the view defines one handler
    per version.

    NOTE(review): this code reads ``request.version``. It assumes something
    upstream (a versioning class or middleware) populated it from the header.
    A request without a ``version`` attribute is served by the unversioned handler.

    param: origin_object, the object being decorated: a class-based View subclass or a function.

    Usage:
    1. Class view: put @versioning on the class and define per-version methods
       named "[http_method]_[version]", e.g. get_v2, post_v3.
    2. Function: put @versioning on the function bound to the URL and, in the
       same module, define per-version functions named "[func_name]_[version]".
       Every version of a function must share the same func_name prefix.
    ***!!! When the function has several decorators, put @versioning topmost so
    the right version is selected before the other decorators run !!!***

    raises: TypeError if origin_object is neither a function nor a View subclass
    (previously such a class was silently replaced by None).
    """
    def _versioned_names(base, names):
        # Keep only '<base>_v<digits>' names. A bare prefix match also picked up
        # members like 'get_context_data' or 'list_all_v2', which crashed or
        # misrouted get_adaptive_method.
        prefix = base + '_v'
        return [name for name in names
                if name.startswith(prefix) and name[len(prefix):].isdecimal()]

    if isfunction(origin_object):
        @wraps(origin_object)
        def wrapper(*args, **kwargs):
            # For function-based views the first positional argument is the request.
            version = getattr(args[0], 'version', None)
            if not version:
                return origin_object(*args, **kwargs)
            base = origin_object.__name__
            module = import_module(origin_object.__module__)
            # Lower-case the version the same way dispatch() does.
            version_func = get_adaptive_method(version.lower(), base,
                                               _versioned_names(base, dir(module)))
            if version_func == base:
                return origin_object(*args, **kwargs)
            return getattr(module, version_func)(*args, **kwargs)
        return wrapper

    if isinstance(origin_object, type) and issubclass(origin_object, View):
        def dispatch(self, request, *args, **kwargs):
            http_method = request.method.lower()
            if http_method not in self.http_method_names:
                handler = self.http_method_not_allowed
            else:
                version = getattr(request, 'version', None)
                if version:
                    candidates = _versioned_names(http_method, dir(origin_object))
                    name = get_adaptive_method(version.lower(), http_method, candidates)
                else:
                    name = http_method
                handler = getattr(self, name, self.http_method_not_allowed)
            return handler(request, *args, **kwargs)

        origin_object.dispatch = dispatch
        return origin_object

    raise TypeError("@versioning can only decorate a function or a View subclass, got %r"
                    % (origin_object,))
def get_adaptive_method(version, http_method, methods):
    """
    Pick the handler name for ``version``, rounding down to the highest
    available version: if no method exists for the requested version, use the
    highest-numbered ``<http_method>_v<N>`` whose N does not exceed it.

    :param version: requested version, e.g. "v3" (its first character is skipped).
    :param http_method: base handler name, e.g. "get" or a view function's name.
    :param methods: candidate names. Names not of the form
        ``<http_method>_v<digits>`` are ignored. The list is not modified.
    :return: the chosen name, or ``http_method`` when nothing matches or the
        version is not numeric.
    """
    exact = http_method + '_' + version
    if exact in methods:
        return exact
    requested = version[1:]
    if not requested.isdecimal():
        # An unparseable version (e.g. "latest") falls back to the unversioned
        # handler instead of crashing the request with ValueError.
        return http_method
    requested = int(requested)
    prefix = http_method + '_v'
    best_name, best_number = http_method, -1
    for name in methods:
        suffix = name[len(prefix):]
        # Parse the number after '<http_method>_v'. The old x.index('v') found the
        # first 'v' anywhere ('review_v2', 'get_context_data') and raised.
        if not name.startswith(prefix) or not suffix.isdecimal():
            continue
        number = int(suffix)
        # Strict '<' keeps the first of equal versions, matching the old stable sort.
        if best_number < number <= requested:
            best_name, best_number = name, number
    return best_name
- 运用在 Django 的表单验证中
from functools import wraps

# HttpRequest is needed by validate_form's isinstance check.
from django.http import HttpRequest, QueryDict
# Fixed typo: the class is MultiValueDict (MultiValueDic raised ImportError).
from django.utils.datastructures import MultiValueDict
def validate_form(form_class):
    """
    Decorator factory: validate the request data with a Django form before the view runs.

    The view receives the form's ``cleaned_data`` as the keyword argument
    ``cleaned_data``. For class-based views wrap it with ``method_decorator`` so
    that the request is the first positional argument.

    :param form_class: a subclass of django.forms.Form
    :raises TypeError: if the first positional argument is not an HttpRequest
        (a subclass of Exception, so existing ``except Exception`` handlers still catch it).
    :raises APIException: with ``errorcode.INVALID_ARGS`` and ``msg=form.errors``
        when validation fails.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            request = args[0] if args else None
            if not isinstance(request, HttpRequest):
                raise TypeError("view的第一个参数必须是request, 如果用于class-based view, "
                                "请使用@method_decorator(validate_form(...))")
            # Expected: request.DATA is a QueryDict and request.FILES a MultiValueDict.
            # NOTE(review): plain Django HttpRequest has no DATA attribute; presumably
            # set by project middleware. Confirm.
            form = form_class(request.DATA, request.FILES)
            if not form.is_valid():
                raise APIException(errorcode.INVALID_ARGS, msg=form.errors)
            return func(*args, **kwargs, cleaned_data=form.cleaned_data)
        return wrapper
    return decorator
示例（views.py）：
@versioning
class BannerView(View):
    """Example versioned view: an unversioned ``get`` plus a ``get_v2`` variant.

    Both handlers validate the request with ``BannerListForm`` and receive the
    validated values as ``cleaned_data``.
    """

    @method_decorator(validate_form(BannerListForm))
    def get(self, request, cleaned_data):
        """Handle GET when no version is requested or no versioned handler applies."""

    @method_decorator(validate_form(BannerListForm))
    def get_v2(self, request, cleaned_data):
        """Handle GET for v2 (and higher versions that have no handler of their own)."""
- 将类方法的计算结果缓存到对象属性中（灵感来自于 bottle 源码）
class cached_property(object):
    """A property that is only computed once per instance and then replaces
    itself with an ordinary attribute. Deleting the attribute resets the
    property.

    This works because the class defines only ``__get__`` (a non-data
    descriptor). Once the value is stored in the instance ``__dict__`` under the
    method's name, normal attribute lookup finds it there and never calls
    ``__get__`` again.
    """

    def __init__(self, func):
        self.func = func
        # Expose the wrapped method's docstring (as bottle does), so introspection
        # of the class attribute shows it instead of this class's generic docstring.
        self.__doc__ = getattr(func, '__doc__')

    def __get__(self, obj, cls):
        if obj is None:
            # Accessed on the class itself: return the descriptor.
            return self
        value = obj.__dict__[self.func.__name__] = self.func(obj)
        return value
def callback_demo():
    """Demo route callback: return the fixed string "helloworld"."""
    greeting = "helloworld"
    return greeting
class Route(object):
    """Minimal route demo whose callback is built on first access and cached per instance."""

    @cached_property
    def call(self):
        """Route callback with all plugins applied.

        Built on demand the first time it is accessed, then cached on the
        instance so later requests skip the work.
        """
        return self._make_callback()

    def _make_callback(self):
        # Applying plugins to the callback gives the same result every time for
        # this route, so it only needs to run once (``call`` caches the result).
        return callback_demo
if __name__ == '__main__':
    # Demo: the instance __dict__ is empty until `call` is first accessed;
    # afterwards the cached callback is stored there under 'call'.
    demo_route = Route()
    print(demo_route.__dict__)
    print(demo_route.call())
    print(demo_route.__dict__)