1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
 
__author__  =  'Michael Liao'
 
import  asyncio, logging
 
import  aiomysql
 
def  log(sql, args = ()):
     logging.info( 'SQL: %s'  %  sql)
 
#代码分为三部分,第一部分是aiomysql模块的应用
async  def  create_pool(loop,  * * kw):
     logging.info( 'create database connection pool...' )
     global  __pool
     __pool  =  await aiomysql.create_pool(
         host = kw.get( 'host' 'localhost' ),
         port = kw.get( 'port' 3306 ),
         user = kw[ 'user' ],
         password = kw[ 'password' ],
         db = kw[ 'db' ],
         charset = kw.get( 'charset' 'utf8' ),
         autocommit = kw.get( 'autocommit' True ),
         maxsize = kw.get( 'maxsize' 10 ),
         minsize = kw.get( 'minsize' 1 ),
         loop = loop
     )
 
async  def  select(sql, args, size = None ):
     log(sql, args)
     global  __pool
     async with __pool.get() as conn:
         async with conn.cursor(aiomysql.DictCursor) as cur:
             await cur.execute(sql.replace( '?' '%s' ), args  or  ())
             if  size:
                 rs  =  await cur.fetchmany(size)
             else :
                 rs  =  await cur.fetchall()
         logging.info( 'rows returned: %s'  %  len (rs))
         return  rs
 
async  def  execute(sql, args, autocommit = True ):
     log(sql)
     async with __pool.get() as conn:
         if  not  autocommit:
             await conn.begin()
         try :
             async with conn.cursor(aiomysql.DictCursor) as cur:
                 await cur.execute(sql.replace( '?' '%s' ), args)
                 affected  =  cur.rowcount
             if  not  autocommit:
                 await conn.commit()
         except  BaseException as e:
             if  not  autocommit:
                 await conn.rollback()
             raise
         return  affected
 
def  create_args_string(num):
     =  []
     for  in  range (num):
         L.append( '?' )
     return  ', ' .join(L)
 
 
#代码分为三部分,第二部分是orm的实际应用
class  Field( object ):
 
     def  __init__( self , name, column_type, primary_key, default):
         self .name  =  name
         self .column_type  =  column_type
         self .primary_key  =  primary_key
         self .default  =  default
 
     def  __str__( self ):
         return  '<%s, %s:%s>'  %  ( self .__class__.__name__,  self .column_type,  self .name)
 
class  StringField(Field):
 
     def  __init__( self , name = None , primary_key = False , default = None , ddl = 'varchar(100)' ):
         super ().__init__(name, ddl, primary_key, default)
 
class  BooleanField(Field):
 
     def  __init__( self , name = None , default = False ):
         super ().__init__(name,  'boolean' False , default)
 
class  IntegerField(Field):
 
     def  __init__( self , name = None , primary_key = False , default = 0 ):
         super ().__init__(name,  'bigint' , primary_key, default)
 
class  FloatField(Field):
 
     def  __init__( self , name = None , primary_key = False , default = 0.0 ):
         super ().__init__(name,  'real' , primary_key, default)
 
class  TextField(Field):
 
     def  __init__( self , name = None , default = None ):
         super ().__init__(name,  'text' False , default)
 
class  ModelMetaclass( type ):
 
     def  __new__( cls , name, bases, attrs):
         if  name = = 'Model' :
             return  type .__new__( cls , name, bases, attrs)
         tableName  =  attrs.get( '__table__' None or  name
         logging.info( 'found model: %s (table: %s)'  %  (name, tableName))
         mappings  =  dict ()
         fields  =  []
         primaryKey  =  None
         for  k, v  in  attrs.items():
             if  isinstance (v, Field):
                 logging.info( '  found mapping: %s ==> %s'  %  (k, v))
                 mappings[k]  =  v
                 if  v.primary_key:
                     # 找到主键:
                     if  primaryKey:
                         raise  BaseException( 'Duplicate primary key for field: %s'  %  k)
                     primaryKey  =  k
                 else :
                     fields.append(k)
         if  not  primaryKey:
             raise  BaseException( 'Primary key not found.' )
         for  in  mappings.keys():
             attrs.pop(k)
         escaped_fields  =  list ( map ( lambda  f:  '`%s`'  %  f, fields))
         attrs[ '__mappings__' =  mappings  # 保存属性和列的映射关系
         attrs[ '__table__' =  tableName
         attrs[ '__primary_key__' =  primaryKey  # 主键属性名
         attrs[ '__fields__' =  fields  # 除主键外的属性名
         attrs[ '__select__' =  'select `%s`, %s from `%s`'  %  (primaryKey,  ', ' .join(escaped_fields), tableName)
         attrs[ '__insert__' =  'insert into `%s` (%s, `%s`) values (%s)'  %  (tableName,  ', ' .join(escaped_fields), primaryKey, create_args_string( len (escaped_fields)  +  1 ))
         attrs[ '__update__' =  'update `%s` set %s where `%s`=?'  %  (tableName,  ', ' .join( map ( lambda  f:  '`%s`=?'  %  (mappings.get(f).name  or  f), fields)), primaryKey)
         attrs[ '__delete__' =  'delete from `%s` where `%s`=?'  %  (tableName, primaryKey)
         return  type .__new__( cls , name, bases, attrs)
 
 
 
#代码分为三部分,第三部分是调用方法
class  Model( dict , metaclass = ModelMetaclass):
 
     def  __init__( self * * kw):
         super (Model,  self ).__init__( * * kw)
 
     def  __getattr__( self , key):
         try :
             return  self [key]
         except  KeyError:
             raise  AttributeError(r "'Model' object has no attribute '%s'"  %  key)
 
     def  __setattr__( self , key, value):
         self [key]  =  value
 
     def  getValue( self , key):
         return  getattr ( self , key,  None )
 
     def  getValueOrDefault( self , key):
         value  =  getattr ( self , key,  None )
         if  value  is  None :
             field  =  self .__mappings__[key]
             if  field.default  is  not  None :
                 value  =  field.default()  if  callable (field.default)  else  field.default
                 logging.debug( 'using default value for %s: %s'  %  (key,  str (value)))
                 setattr ( self , key, value)
         return  value
 
     @ classmethod
     async  def  findAll( cls , where = None , args = None * * kw):
         ' find objects by where clause. '
         sql  =  [ cls .__select__]
         if  where:
             sql.append( 'where' )
             sql.append(where)
         if  args  is  None :
             args  =  []
         orderBy  =  kw.get( 'orderBy' None )
         if  orderBy:
             sql.append( 'order by' )
             sql.append(orderBy)
         limit  =  kw.get( 'limit' None )
         if  limit  is  not  None :
             sql.append( 'limit' )
             if  isinstance (limit,  int ):
                 sql.append( '?' )
                 args.append(limit)
             elif  isinstance (limit,  tuple and  len (limit)  = =  2 :
                 sql.append( '?, ?' )
                 args.extend(limit)
             else :
                 raise  ValueError( 'Invalid limit value: %s'  %  str (limit))
         rs  =  await select( ' ' .join(sql), args)
         return  [ cls ( * * r)  for  in  rs]
 
     @ classmethod
     async  def  findNumber( cls , selectField, where = None , args = None ):
         ' find number by select and where. '
         sql  =  [ 'select %s _num_ from `%s`'  %  (selectField,  cls .__table__)]
         if  where:
             sql.append( 'where' )
             sql.append(where)
         rs  =  await select( ' ' .join(sql), args,  1 )
         if  len (rs)  = =  0 :
             return  None
         return  rs[ 0 ][ '_num_' ]
 
     @ classmethod
     async  def  find( cls , pk):
         ' find object by primary key. '
         rs  =  await select( '%s where `%s`=?'  %  ( cls .__select__,  cls .__primary_key__), [pk],  1 )
         if  len (rs)  = =  0 :
             return  None
         return  cls ( * * rs[ 0 ])
 
     async  def  save( self ):
         args  =  list ( map ( self .getValueOrDefault,  self .__fields__))
         args.append( self .getValueOrDefault( self .__primary_key__))
         rows  =  await execute( self .__insert__, args)
         if  rows ! =  1 :
             logging.warn( 'failed to insert record: affected rows: %s'  %  rows)
 
     async  def  update( self ):
         args  =  list ( map ( self .getValue,  self .__fields__))
         args.append( self .getValue( self .__primary_key__))
         rows  =  await execute( self .__update__, args)
         if  rows ! =  1 :
             logging.warn( 'failed to update by primary key: affected rows: %s'  %  rows)
 
     async  def  remove( self ):
         args  =  [ self .getValue( self .__primary_key__)]
         rows  =  await execute( self .__delete__, args)
         if  rows ! =  1 :
             logging.warn( 'failed to remove by primary key: affected rows: %s'  %  rows)



本文转自 liqius 51CTO博客,原文链接:http://blog.51cto.com/szgb17/1941168,如需转载请自行联系原作者