# coding: utf-8
"""Thin MySQL helper: dict rows with dot access (``Struct``), lazily
connected per-alias proxies (``Hub`` / ``ConnectionProxy``) and a chainable
query builder (``QuerySet``).
"""
import time
import copy
import sys
import logging
import threading

import mysql_pool

__all__ = [
    'Struct',
    'ConnectionProxy',
    'Hub',
]

# Python 2/3 compatibility aliases (the module was written for Python 2).
try:
    _STRING_TYPES = (str, unicode)  # Python 2: str is the byte string
    _INTEGER_TYPES = (int, long)
    _UNICODE_TYPE = unicode
except NameError:  # Python 3
    _STRING_TYPES = (str, bytes)
    _INTEGER_TYPES = (int,)
    _UNICODE_TYPE = None


class Struct(dict):
    """A dict that also supports attribute (dot) access.

    Missing keys read as ``None`` instead of raising::

        >>> o = Struct({'a': 1})
        >>> o.a
        1
        >>> o.b is None
        True
    """

    def __init__(self, *e, **f):
        # Only the first positional argument is used (extra ones are ignored).
        if e:
            self.update(e[0])
        if f:
            self.update(f)

    def __getattr__(self, name):
        # Pickle (and other protocol code) looks up dunders such as
        # "__getstate__"; those must raise AttributeError, not return None.
        if name.startswith('__'):
            raise AttributeError(name)
        return self.get(name)

    def __setattr__(self, name, val):
        self[name] = val

    def __delattr__(self, name):
        self.pop(name, None)

    def __hash__(self):
        # Identity hash so instances can be set members / dict keys.
        # NOTE: equality is still dict equality, so two equal Structs
        # generally hash differently.
        return id(self)


class Executer:
    """Context manager that runs one statement on a proxy's connection.

    Acquires the connection's lock and yields a fresh cursor. On exit it
    records the last executed SQL on the proxy, closes the cursor, releases
    the lock and, when no transaction is open and autocommit is on, closes
    the proxy's connection.
    """

    def __init__(self, proxy):
        self.p = proxy
        self.c = proxy.connect()
        self.cursor = None

    def __enter__(self):
        self.c._lock.acquire()
        try:
            self.cursor = self.c.cursor()
        except BaseException:
            # __exit__ is not called when __enter__ raises: release here or
            # the lock stays held forever.
            self.c._lock.release()
            raise
        return self.cursor

    def __exit__(self, exc, value, tb):
        try:
            cursor = self.cursor
            # Driver-private attribute; its name differs between drivers and
            # versions, so try both spellings.
            self.p.last_executed = (getattr(cursor, '_last_executed', None)
                                    or getattr(cursor, '_executed', None))
            cursor.close()
        finally:
            self.c._lock.release()
        if not self.p.transacting and self.p.get_autocommit():
            self.p.close()


class ConnectionProxy:
    """Lazily connected wrapper around one connection made by ``creator``.

    With autocommit on and no transaction open, the connection is closed
    again after every statement (see ``Executer``). Using the proxy as a
    context manager wraps the block in a transaction. Unknown attribute
    names are treated as table names and return a ``QuerySet``::

        with proxy:
            proxy.auth_user.filter(id=1).update(name='x')
    """

    def __init__(self, creator):
        self.creator = creator
        self.c = None
        self.transacting = False
        self.last_executed = None

    def connect(self):
        """Return the current connection, creating it on first use."""
        if self.c:
            return self.c
        conn = self.creator()
        # Serializes cursor use on this connection (see Executer).
        conn._lock = threading.Lock()
        self.c = conn
        return conn

    def close(self):
        """Close and forget the current connection, if any."""
        if self.c:
            self.c.close()
            self.c = None

    @property
    def open(self):
        """return if connection alive"""
        return self.c is not None and self.c.open

    # --- pass-throughs to the underlying connection (connecting if needed)

    def character_set_name(self):
        return self.connect().character_set_name()

    def set_character_set(self, charset):
        return self.connect().set_character_set(charset)

    def literal(self, s):
        return self.connect().literal(s)

    def escape_string(self, s):
        return self.connect().escape_string(s)

    def get_autocommit(self):
        return self.connect().get_autocommit()

    def autocommit(self, on):
        return self.connect().autocommit(on)

    def query(self, command):
        return self.connect().query(command)

    def begin(self):
        return self.query("BEGIN")

    def commit(self):
        assert self.c, 'Need connect before commit!'
        self.c.commit()

    def rollback(self):
        assert self.c, 'Need connect before rollback!'
        self.c.rollback()

    # --- statement helpers

    def fetchall(self, sql, args=None):
        """Execute sql and return all rows as the cursor yields them."""
        with Executer(self) as cursor:
            cursor.execute(sql, args)
            rows = cursor.fetchall()
            return rows

    def fetchone(self, sql, args=None):
        """Execute sql and return the first row (or the cursor's empty value)."""
        with Executer(self) as cursor:
            cursor.execute(sql, args)
            row = cursor.fetchone()
            return row

    def fetchall_dict(self, sql, args=None):
        """Execute sql and return every row as a Struct keyed by column name."""
        with Executer(self) as cursor:
            cursor.execute(sql, args)
            fields = [r[0] for r in cursor.description]
            rows = cursor.fetchall()
            return [Struct(zip(fields, row)) for row in rows]

    def fetchone_dict(self, sql, args=None):
        """Execute sql and return the first row as a Struct, or None."""
        with Executer(self) as cursor:
            cursor.execute(sql, args)
            row = cursor.fetchone()
            if not row:
                return
            fields = [r[0] for r in cursor.description]
            return Struct(zip(fields, row))

    def execute(self, sql, args=None):
        """
        Returns affected rows and lastrowid.
        """
        with Executer(self) as cursor:
            cursor.execute(sql, args)
            return cursor.rowcount, cursor.lastrowid

    def execute_many(self, sql, args=None):
        """
        Execute a multi-row query.

        Returns affected rows.
        """
        with Executer(self) as cursor:
            rows = cursor.executemany(sql, args)
            return rows

    def callproc(self, procname, *args):
        """Execute stored procedure procname with args.

        Returns the first result row (``cursor.fetchone()``).
        """
        with Executer(self) as cursor:
            cursor.callproc(procname, args)
            rows = cursor.fetchone()
            return rows

    # --- transactions

    def __enter__(self):
        """Begin a transaction"""
        self.transacting = True
        # With autocommit on, start the transaction explicitly.
        if self.get_autocommit():
            self.begin()
        self.c._transacting = True
        return self

    def __exit__(self, exc, value, tb):
        """End a transaction: rollback on exception, commit otherwise."""
        try:
            if exc:
                self.rollback()
            else:
                self.commit()
        finally:
            self.transacting = False
            # Guard so a missing connection does not mask the real error.
            if self.c is not None:
                self.c._transacting = False
            self.close()

    def __getattr__(self, table_name):
        """Return a QuerySet for ``table_name`` (e.g. ``proxy.auth_user``)."""
        # Special names are never tables: protocol lookups such as
        # __repr__/__nonzero__/__getstate__ must see AttributeError.
        if table_name.startswith('__'):
            raise AttributeError(table_name)
        return QuerySet(self, table_name)

    def __str__(self):
        return '<ConnectionProxy: %x>' % (id(self))


class Hub:
    """Registry of connection pools by alias.

    Usage::

        import MySQLdb
        db = Hub(MySQLdb)
        db.add_pool('default', host='', port=3306, user='', passwd='', db='',
                    charset='utf8', autocommit=True, pool_size=8,
                    wait_timeout=30)
        db.default.auth_user.get(id=1)

    :param driver: MySQLdb or pymysql
    """

    def __init__(self, driver):
        self.pool_manager = mysql_pool.PoolManager(driver)
        self.creators = {}

    def add_pool(self, alias, **connect_kwargs):
        """Register a pool under ``alias``; connections are made on demand.

        ``connect_kwargs`` are forwarded to ``pool_manager.connect``.

        :param pool_size: (optional) connection pool capacity
        :param wait_timeout: (optional) maximum time (seconds) a connection is kept
        """
        # Timeout before throwing an exception when connecting.
        # (default: 10, min: 1, max: 31536000)
        connect_kwargs.setdefault('connect_timeout', 10)

        def creator():
            return self.pool_manager.connect(**connect_kwargs)

        self.creators[alias] = creator

    def get_proxy(self, alias):
        """Return a new ConnectionProxy for alias, or None if not registered."""
        creator = self.creators.get(alias)
        if creator:
            return ConnectionProxy(creator)

    def __getattr__(self, alias):
        """Return a proxy connection for a database alias."""
        # Special names are never aliases (see ConnectionProxy.__getattr__).
        if alias.startswith('__'):
            raise AttributeError(alias)
        return self.get_proxy(alias)

    def __str__(self):
        return '<Hub: %s>' % id(self)


class QuerySet:
    """Chainable, lazily evaluated query on one table.

    Builder methods (filter, exclude, order_by, group_by, select, values,
    flat, ondup) return a modified clone and never change ``self``. SQL runs
    only when rows are needed (iteration, indexing, len(), bool(), get(),
    count(), ...). Filter keys are ``field`` or ``field__op`` with op one of
    gt, gte, lt, lte, ne, in, ni, startswith, endswith, contains, range.
    """

    LOOKUP_SEP = '__'

    # Lookup suffixes that map directly onto a binary SQL operator.
    _COMPARE_OPS = {'gt': '>', 'gte': '>=', 'lt': '<', 'lte': '<='}

    def __init__(self, conn, table_name, db_name=''):
        """conn: a ConnectionProxy-like object (literal, escape_string,
        character_set_name, fetch*/execute* methods)."""
        self.conn = conn
        self.db_name = db_name
        self.table_name = "%s.%s" % (db_name, table_name) if db_name else table_name
        self.select_list = []
        self.cond_list = []
        self.cond_dict = {}
        self.exclude_list = []
        self.exclude_dict = {}
        self.order_list = []
        self.group_list = []
        self.ondup_list = []
        self.ondup_dict = {}
        self.having = ''
        self.limits = []
        self.row_style = 0  # Element type, 0:dict, 1:2d list 2:flat list
        self._result = None
        self._exists = None
        self._count = None

    # --- SQL fragment builders

    def literal(self, value):
        """Quote value for SQL; non-string iterables become ``(a,b,...)``."""
        if hasattr(value, '__iter__') and not isinstance(value, _STRING_TYPES):
            return '(' + ','.join(self.conn.literal(v) for v in value) + ')'
        return self.conn.literal(value)

    def escape_string(self, s):
        """Escape s for embedding inside a quoted SQL string."""
        if _UNICODE_TYPE is not None and isinstance(s, _UNICODE_TYPE):
            # Python 2: encode text to the connection charset first.
            charset = self.conn.character_set_name()
            try:
                s = s.encode(charset)
            except (LookupError, UnicodeError):
                # unknown python encoding (or unencodable): pass through
                pass
        return self.conn.escape_string(s)

    def make_select(self, fields):
        if not fields:
            return '*'
        return ','.join(fields)

    def make_expr(self, key, v):
        """Build the SQL condition for one ``field[__op]=value`` lookup."""
        name, _, op = key.partition(self.LOOKUP_SEP)
        field = "`%s`" % name
        if not op:
            if v is None:
                return field + ' is null'
            return field + '=' + self.literal(v)
        if op in self._COMPARE_OPS:
            return field + self._COMPARE_OPS[op] + self.literal(v)
        if op == 'ne':
            if v is None:
                return field + ' is not null'
            return field + '!=' + self.literal(v)
        if op == 'in':
            # "x in ()" is invalid SQL; an empty set matches nothing.
            if not v:
                return '0'
            return field + ' in ' + self.literal(v)
        if op == 'ni':  # not in
            # ...and "not in" an empty set matches everything.
            if not v:
                return '1'
            return field + ' not in ' + self.literal(v)
        # NOTE(review): LIKE wildcards (% and _) inside v are not escaped
        # here, so they still act as wildcards.
        if op == 'startswith':
            return field + ' like ' + "'%s%%'" % self.escape_string(v)
        if op == 'endswith':
            return field + ' like ' + "'%%%s'" % self.escape_string(v)
        if op == 'contains':
            return field + ' like ' + "'%%%s%%'" % self.escape_string(v)
        if op == 'range':
            return field + ' between ' + "%s and %s" % (self.literal(v[0]), self.literal(v[1]))
        # Unknown suffix: the raw key (unquoted) is used as the column.
        return key + '=' + self.literal(v)

    def make_cond(self, args, kw):
        """AND together raw conditions ``args`` and field lookups ``kw``."""
        parts = ['(%s)' % v for v in args]
        exprs = (self.make_expr(k, v) for k, v in kw.items())
        parts.extend(e for e in exprs if e)
        return ' and '.join(parts)

    def make_where(self, cond_list, cond_dict, exclude_list, exclude_dict):
        cond = self.make_cond(cond_list, cond_dict)
        exclude = self.make_cond(exclude_list, exclude_dict)
        if cond and exclude:
            return 'where %s and not (%s)' % (cond, exclude)
        elif cond:
            return 'where %s' % cond
        elif exclude:
            return 'where not (%s)' % exclude
        return ''

    def make_order_by(self, fields):
        """'-f' sorts descending, '?' sorts randomly."""
        if not fields:
            return ''
        real_fields = []
        for f in fields:
            if f == '?':
                f = 'rand()'
            elif f.startswith('-'):
                f = f[1:] + ' desc'
            real_fields.append(f)
        return 'order by ' + ','.join(real_fields)

    def reverse_order_list(self):
        """Flip every sort direction; with no ordering, assume ``-id``."""
        if not self.order_list:
            self.order_list = ['-id']
        else:
            orders = []
            for s in self.order_list:
                if s == '?':
                    pass
                elif s.startswith('-'):
                    s = s[1:]
                else:
                    s = '-' + s
                orders.append(s)
            self.order_list = orders

    def make_group_by(self, fields):
        if not fields:
            return ''
        having = ' having %s' % self.having if self.having else ''
        return 'group by ' + ','.join(fields) + having

    def make_limit(self, limits):
        """Build a LIMIT clause from ``[start, stop]`` (Python slice bounds).

        A missing ``stop`` means no LIMIT at all (``start`` is then ignored).
        """
        if not limits:
            return ''
        start, stop = limits
        if stop is None:
            return ''
        if not start:
            return 'limit %s' % stop
        # Clamp so that stop <= start gives an empty result, like slicing.
        return 'limit %s, %s' % (start, max(stop - start, 0))

    def make_query(self, select_list=None, cond_list=None, cond_dict=None,
                   exclude_list=None, exclude_dict=None, group_list=None,
                   order_list=None, limits=None):
        """Build the SELECT statement; each None argument uses self's value."""
        if select_list is None:
            select_list = self.select_list
        if cond_list is None:
            cond_list = self.cond_list
        if cond_dict is None:
            cond_dict = self.cond_dict
        if exclude_list is None:
            exclude_list = self.exclude_list
        if exclude_dict is None:
            exclude_dict = self.exclude_dict
        if order_list is None:
            order_list = self.order_list
        if group_list is None:
            group_list = self.group_list
        if limits is None:
            limits = self.limits

        select = self.make_select(select_list)
        cond = self.make_where(cond_list, cond_dict, exclude_list, exclude_dict)
        order = self.make_order_by(order_list)
        group = self.make_group_by(group_list)
        limit = self.make_limit(limits)
        sql = "select %s from %s %s %s %s %s" % (select, self.table_name, cond, group, order, limit)
        return sql

    @property
    def sql(self):
        return self.make_query()

    # --- evaluation and cloning

    def flush(self):
        """Run the query once and cache the rows (shape per row_style)."""
        # "is not None" so that an empty result is cached too.
        if self._result is not None:
            return self._result
        sql = self.make_query()
        if self.row_style == 1:
            self._result = self.conn.fetchall(sql)
        elif self.row_style == 2:
            self._result = [v for row in self.conn.fetchall(sql) for v in row]
        else:
            self._result = self.conn.fetchall_dict(sql)
        return self._result

    def clone(self):
        """Return an independent copy with no cached results."""
        new = copy.copy(self)
        # copy.copy is shallow: give the clone its own containers so that
        # filter()/exclude()/group_by() on it never modify this queryset.
        new.select_list = list(self.select_list)
        new.cond_list = list(self.cond_list)
        new.cond_dict = dict(self.cond_dict)
        new.exclude_list = list(self.exclude_list)
        new.exclude_dict = dict(self.exclude_dict)
        new.order_list = list(self.order_list)
        new.group_list = list(self.group_list)
        new.ondup_list = list(self.ondup_list)
        new.ondup_dict = dict(self.ondup_dict)
        new.limits = list(self.limits)
        new._result = None
        new._exists = None
        new._count = None
        return new

    def group_by(self, *fields, **kw):
        q = self.clone()
        q.group_list.extend(fields)
        q.having = kw.get('having') or ''
        return q

    def order_by(self, *fields):
        q = self.clone()
        q.order_list = fields
        return q

    def select(self, *fields):
        """Rows as Structs."""
        q = self.clone()
        q.row_style = 0
        if fields:
            q.select_list = fields
        return q

    def values(self, *fields):
        """Rows as tuples."""
        q = self.clone()
        q.row_style = 1
        if fields:
            q.select_list = fields
        return q

    def flat(self, *fields):
        """All column values of all rows in one flat list."""
        q = self.clone()
        q.row_style = 2
        if fields:
            q.select_list = fields
        return q

    def get(self, *args, **kw):
        """Return the first row matching the extra conditions, or None."""
        cond_dict = dict(self.cond_dict)
        cond_dict.update(kw)
        cond_list = self.cond_list + list(args)
        sql = self.make_query(cond_list=cond_list, cond_dict=cond_dict, limits=(None, 1))
        if self.row_style == 1:
            return self.conn.fetchone(sql)
        else:
            return self.conn.fetchone_dict(sql)

    def filter(self, *args, **kw):
        q = self.clone()
        q.cond_dict.update(kw)
        q.cond_list.extend(args)
        return q

    def exclude(self, *args, **kw):
        q = self.clone()
        q.exclude_dict.update(kw)
        q.exclude_list.extend(args)
        return q

    def first(self):
        return self[0]

    def last(self):
        return self[-1]

    # --- writes

    def ondup(self, *args, **kw):
        """
        MySQL feature: INSERT...ON DUPLICATE KEY UPDATE...
        """
        q = self.clone()
        q.ondup_list = list(args)
        q.ondup_dict = dict(kw)
        return q

    def _make_insert(self, fields, ignore):
        """Build an INSERT with one %s placeholder per column in fields."""
        columns = ','.join("`%s`" % f for f in fields)
        tokens = ','.join(['%s'] * len(fields))
        ignore_s = ' IGNORE' if ignore else ''
        ondup_s = ''
        if self.ondup_list or self.ondup_dict:
            update_fields = self.make_update_fields(self.ondup_list, self.ondup_dict)
            ondup_s = ' ON DUPLICATE KEY UPDATE ' + update_fields
        return "insert%s into %s (%s) values (%s)%s" % (
            ignore_s, self.table_name, columns, tokens, ondup_s)

    def create(self, ignore=False, **kw):
        "Returns lastrowid"
        fields = list(kw)
        sql = self._make_insert(fields, ignore)
        _, lastid = self.conn.execute(sql, [kw[f] for f in fields])
        return lastid

    def bulk_create(self, obj_list, ignore=False):
        """Insert many dicts sharing the first dict's keys.

        Returns affected rows (0 for an empty list). Values are looked up by
        column name, so dict key order does not matter; a dict missing one
        of the columns raises KeyError.
        """
        if not obj_list:
            return 0
        fields = list(obj_list[0])
        sql = self._make_insert(fields, ignore)
        args = [[o[f] for f in fields] for o in obj_list]
        return self.conn.execute_many(sql, args)

    def count(self):
        """Return the number of matching rows (cached after first call)."""
        if self._count is not None:
            return self._count
        if self._result is not None:
            return len(self._result)
        sql = self.make_query(select_list=['count(*) n'], order_list=[], limits=[None, 1])
        row = self.conn.fetchone(sql)
        n = row[0] if row else 0
        self._count = n
        return n

    def exists(self):
        """Return True if at least one row matches (cached)."""
        if self._result is not None:
            # A fetched but empty result means "no rows".
            return bool(self._result)
        if self._exists is not None:
            return self._exists
        sql = self.make_query(select_list=['1'], order_list=[], limits=[None, 1])
        row = self.conn.fetchone(sql)
        b = bool(row)
        self._exists = b
        return b

    def make_update_fields(self, args=(), kw=None):
        """Build a SET list from raw expressions and ``column=value`` pairs."""
        kw = kw or {}
        f1 = ', '.join(args)
        f2 = ', '.join('`%s`=%s' % (k, self.literal(v)) for k, v in kw.items())
        if f1 and f2:
            return f1 + ', ' + f2
        return f1 or f2

    def update(self, *args, **kw):
        "return affected rows"
        if not args and not kw:
            return 0
        cond = self.make_where(self.cond_list, self.cond_dict, self.exclude_list, self.exclude_dict)
        update_fields = self.make_update_fields(args, kw)
        sql = "update %s set %s %s" % (self.table_name, update_fields, cond)
        n, _ = self.conn.execute(sql)
        return n

    def delete(self, *names):
        "return affected rows"
        cond = self.make_where(self.cond_list, self.cond_dict, self.exclude_list, self.exclude_dict)
        limit = self.make_limit(self.limits)
        d_names = ','.join(names)
        sql = "delete %s from %s %s %s" % (d_names, self.table_name, cond, limit)
        n, _ = self.conn.execute(sql)
        return n

    # --- container protocol

    def __iter__(self):
        rows = self.flush()
        return iter(rows)

    def __len__(self):
        return self.count()

    def __getitem__(self, k):
        if self._result is not None:
            return self._result[k]

        q = self.clone()
        if isinstance(k, _INTEGER_TYPES):
            if k < 0:
                # Item -n from the end is item n-1 under the reversed order.
                k = -k - 1
                q.reverse_order_list()
            q.limits = [k, k + 1]
            rows = q.flush()
            return rows[0] if rows else None
        elif isinstance(k, slice):
            start = None if k.start is None else int(k.start)
            stop = None if k.stop is None else int(k.stop)
            assert k.step is None, 'Slice step is not supported.'
            # Python 2 old-style instances get sys.maxsize as the stop of an
            # open-ended slice such as q[5:].
            if stop == sys.maxsize:
                stop = None
            # MySQL cannot OFFSET without a LIMIT, so bound it by the count.
            if start and stop is None:
                stop = self.count()
            q.limits = [start, stop]
            return q.flush()

    def __bool__(self):
        return self.exists()

    def __nonzero__(self):  # Python 2 compatibility
        return self.exists()

    def wait(self, *args, **kw):
        """Extension: re-read the replica (slave) DB until data appears,
        meaning replication has caught up.

        Polls ``get(*args, **kw)`` with growing delays (4 s in total) and
        returns the first truthy row; logs a warning and returns None if the
        row never appears.
        """
        delays = [0, 0.2, 0.4, 0.8, 1.2, 1.4]
        for dt in delays:
            if dt > 0:
                time.sleep(dt)
            r = self.get(*args, **kw)
            if r:
                return r
        logging.warning('slave db sync timeout: %s', self.table_name)
# Django database queries (django 数据库查询)
# Latest recommended article published 2024-05-23 11:53:14 (最新推荐文章于 2024-05-23 11:53:14 发布)