性感的装饰器

装饰器在文件加载时就会被调用
运用在mysql中
import functools

import pymysql
from pymysql.cursors import DictCursor
from DBUtils.PooledDB import PooledDB


class Mysql(object):
    __pool = None

    def __init__(self):
        self._conn = Mysql.__get_conn__()
        self._cursor = self._conn.cursor()

    @staticmethod
    def __get_conn__():
        if Mysql.__pool is None:
            Mysql.__pool = PooledDB(creator=pymysql, maxconnections=100, mincached=1, blocking=True,
                                    host=conf.DB_HOST, port=conf.DB_PORT, user=conf.DB_USER, password=conf.pwd,
                                    database=conf.DB_NAME, charset="utf8", cursorclass=DictCursor)
        return Mysql.__pool.connection()

    @property
    def connection(self):
        return self._conn

    @property
    def cursor(self):
        return self._cursor

def db_wrap(method):
    @functools.wraps(method)
    def _wrapper(*args, **kwargs):
        conn = Mysql().connection
        cursor = conn.cursor()
        try:
            conn.begin()
            retval = method(cursor, *args, **kwargs)
            conn.commit()
        except:
            conn.rollback()
            raise
        finally:
            cursor.close()
            conn.close()
        return retval

    return _wrapper

# e.g. param cursor, table_name, {id:str, pre:str, user:str, ...}
@db_wrap
def insert(cursor, table_name, values):
    data_format = ""
    data = []

    if not len(values.keys()):
        return 0

    for key, value in values.items():
        data.append(value)
    data_format = ("%s, " * len(values.keys())).rstrip(', ')

    sql = "insert into %s(%s) values(%s)" % (table_name, ','.join(values.keys()), data_format)
    res = cursor.execute(sql, tuple(data))
    return res

# param cursor, table_name, [{},{}, ...]
@db_wrap
def insert_list(cursor, table_name, values, is_license_info=False):
    data_format = ""
    data = []
    keys = []

    if not len(values):
        return 0

    for n in values:
        line = []
        if not len(n):
            return 0
        for key, value in n.items():
            line.append(value)
        data.append(tuple(line))
        if not data_format:
            data_format = ("%s, " * len(n.keys())).rstrip(', ')
            keys = n.keys()

    if is_license_info:
        sql = "INSERT IGNORE INTO  %s(%s) VALUES(%s)" % (table_name, ','.join(keys), data_format)
    else:
        sql = "INSERT INTO  %s(%s) VALUES(%s)" % (table_name, ','.join(keys), data_format)
    #
    res = cursor.executemany(sql, data)
    # sql = sql % data[0]
    # res = cursor.execute(sql)
    return res
  • 运用在django版本控制中
def versioning(origin_object):
    """
    接口版本装饰器, 不同版本的服务不需要改变url,只要在请求中定义'X-Api-Version',以及定义不同版本的接口方法
    param: origin_object, 被装饰的对象可以是class View或者func

    使用方法:
        1. 装饰class view: 给class加上@versioning装饰器,
           在类中定义各种版本的方法 "[http_method]_[version]",例如 get_v2, post_v3

        2. 装饰func: 给和url绑定的func加上@versioning装饰器,
           在模块中, 定义该func各种版本的方法 "[func_name]_[version]",同一个方法的不同版本要保证func_name一致

           ***!!! 注意当func有多个装饰器时,将 @versioning 放在最上面,匹配到正确的方法再去执行其他的decorator !!!***
    """

    def dispatch(self, request, *args, **kwargs):
        methods = []
        for i in dir(origin_object):
            if i.startswith(request.method.lower()) and i != request.method.lower():
                methods.append(i)

        if request.method.lower() in self.http_method_names:
            if request.version:
                handler = getattr(self, get_adaptive_method(request.version.lower(), request.method.lower(), methods),
                                  self.http_method_not_allowed)
            else:
                handler = getattr(self, request.method.lower(), self.http_method_not_allowed)
        else:
            handler = self.http_method_not_allowed
        return handler(request, *args, **kwargs)

    @wraps(origin_object)
    def wrapper(*args, **kwargs):
        version = args[0].version
        if not version:
            return origin_object(*args, **kwargs)
        else:
            module = origin_object.__module__
            module = import_module(module)
            members = dir(module)
            methods = [member for member in members if
                       member.startswith(origin_object.__name__) and member != origin_object.__name__]

            version_func = get_adaptive_method(version, origin_object.__name__, methods)

            if version_func != origin_object.__name__:
                func_call = getattr(module, version_func)
                return func_call(*args, **kwargs)
            else:
                return origin_object(*args, **kwargs)

    if isfunction(origin_object):
        return wrapper

    if issubclass(origin_object, View):
        origin_object.dispatch = dispatch
        return origin_object


def get_adaptive_method(version, http_method, methods):
    """
    根据version匹配方法,向下最大匹配的原则,找不到当前版本,就向下取最大版本号的方法
    """
    if http_method + '_' + version in methods:
        return http_method + '_' + version

    methods.sort(key=lambda x: int(x[x.index('v') + 1:]), reverse=True)

    for m in methods:
        if int(version[1:]) >= int(m[m.index('v') + 1:]):
            return m
    return http_method
  • 运用在django的表单验证中
from django.http import QueryDict
from django.utils.datastructures import MultiValueDic
from functools import wraps


def validate_form(form_class):
    """
    :param form_class: django.forms.Form的子类
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            request = args[0]
            if not isinstance(request, HttpRequest):
                raise Exception("view的第一个参数必须是request, 如果用于class-based view, "
                                "请使用@method_decorator(require_login)")
            form = form_class(request.DATA, request.FILES) # request.DATA 为QueryDict的对象,request.FILES为MultiValueDic的对象
            if not form.is_valid():
                raise APIException(errorcode.INVALID_ARGS, msg=form.errors)
            return func(*args, **kwargs, cleaned_data=form.cleaned_data)

        return wrapper

    return decorator

eg: views.py

@versioning
class BannerView(View):
    @method_decorator(validate_form(BannerListForm))
    def get(self, request, cleaned_data):
        ...
    
    @method_decorator(validate_form(BannerListForm))
    def get_v2(self, request, cleaned_data):
        ...
  • 将类的方法缓存到对象属性中(灵感来自与bottle源码)
class cached_property(object):
    ''' A property that is only computed once per instance and then replaces
        itself with an ordinary attribute. Deleting the attribute resets the
        property. '''

    def __init__(self, func):
        self.func = func

    def __get__(self, obj, cls):
        if obj is None: return self
        value = obj.__dict__[self.func.__name__] = self.func(obj)
        return value


def callback_demo():
    return "helloworld"


class Route(object):
    @cached_property
    def call(self):
        ''' The route callback with all plugins applied. This property is
            created on demand and then cached to speed up subsequent requests.'''
        return self._make_callback()

    def _make_callback(self):  # 给callback增加插件的的操作对路由的所有对象一致,只需要执行一次
        callback = callback_demo
        return callback
        
if __name__ == '__main__':
    route = Route()
    print(route.__dict__)
    print(route.call())
    print(route.__dict__)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值