在一個Web App中,所有數據,包括用戶信息、發布的日志、評論等,都存儲在數據庫中。在awesome-python3-webapp中,我們選擇MySQL作為數據庫。
Web App里面有很多地方都要訪問數據庫。訪問數據庫需要創建數據庫連接、游標對象,然后執行SQL語句,最后處理異常,清理資源。這些訪問數據庫的代碼如果分散到各個函數中,勢必無法維護,也不利於代碼復用。
所以,我們要首先把常用的SELECT、INSERT、UPDATE和DELETE操作用函數封裝起來。
由於Web框架使用了基於asyncio的aiohttp,這是基於協程的異步模型。在協程中,不能調用普通的同步IO操作,因為所有用戶都是由一個線程服務的,協程的執行速度必須非常快,才能處理大量用戶的請求。而耗時的IO操作不能在協程中以同步的方式調用,否則,等待一個IO操作時,系統無法響應任何其他用戶。
這就是異步編程的一個原則:一旦決定使用異步,則系統每一層都必須是異步,“開弓沒有回頭箭”。
幸運的是aiomysql
為MySQL數據庫提供了異步IO的驅動。
創建連接池
我們需要創建一個全局的連接池,每個HTTP請求都可以從連接池中直接獲取數據庫連接。使用連接池的好處是不必頻繁地打開和關閉數據庫連接,而是能復用就盡量復用。
連接池由全局變量__pool
存儲,缺省情況下將編碼設置為utf8
,自動提交事務:
@asyncio.coroutine def create_pool(loop, **kw): logging.info('create database connection pool...') global __pool __pool = yield from aiomysql.create_pool( host=kw.get('host', 'localhost'), port=kw.get('port', 3306), user=kw['user'], password=kw['password'], db=kw['db'], charset=kw.get('charset', 'utf8'), autocommit=kw.get('autocommit', True), maxsize=kw.get('maxsize', 10), minsize=kw.get('minsize', 1), loop=loop )
Select
要執行SELECT語句,我們用select
函數執行,需要傳入SQL語句和SQL參數:
@asyncio.coroutine def select(sql, args, size=None): log(sql, args) global __pool with (yield from __pool) as conn: cur = yield from conn.cursor(aiomysql.DictCursor) yield from cur.execute(sql.replace('?', '%s'), args or ()) if size: rs = yield from cur.fetchmany(size) else: rs = yield from cur.fetchall() yield from cur.close() logging.info('rows returned: %s' % len(rs)) return rs
SQL語句的占位符是?
,而MySQL的占位符是%s
,select()
函數在內部自動替換。注意要始終堅持使用帶參數的SQL,而不是自己拼接SQL字符串,這樣可以防止SQL注入攻擊。
注意到yield from
將調用一個子協程(也就是在一個協程中調用另一個協程)並直接獲得子協程的返回結果。
如果傳入size
參數,就通過fetchmany()
獲取最多指定數量的記錄,否則,通過fetchall()
獲取所有記錄。
Insert, Update, Delete
要執行INSERT、UPDATE、DELETE語句,可以定義一個通用的execute()
函數,因為這3種SQL的執行都需要相同的參數,以及返回一個整數表示影響的行數:
@asyncio.coroutine def execute(sql, args): log(sql) with (yield from __pool) as conn: try: cur = yield from conn.cursor() yield from cur.execute(sql.replace('?', '%s'), args) affected = cur.rowcount yield from cur.close() except BaseException as e: raise return affected
execute()
函數和select()
函數所不同的是,cursor對象不返回結果集,而是通過rowcount
返回結果數。
ORM
有了基本的select()
和execute()
函數,我們就可以開始編寫一個簡單的ORM了。
設計ORM需要從上層調用者角度來設計。
我們先考慮如何定義一個User
對象,然后把數據庫表users
和它關聯起來。
from orm import Model, StringField, IntegerField class User(Model): __table__ = 'users' id = IntegerField(primary_key=True) name = StringField()
注意到定義在User
類中的__table__
、id
和name
是類的屬性,不是實例的屬性。所以,在類級別上定義的屬性用來描述User
對象和表的映射關系,而實例屬性必須通過__init__()
方法去初始化,所以兩者互不干擾:
# 創建實例:
user = User(id=123, name='Michael')
# 存入數據庫:
user.insert() # 查詢所有User對象: users = User.findAll()
定義Model
首先要定義的是所有ORM映射的基類Model
:
class Model(dict, metaclass=ModelMetaclass): def __init__(self, **kw): super(Model, self).__init__(**kw) def __getattr__(self, key): try: return self[key] except KeyError: raise AttributeError(r"'Model' object has no attribute '%s'" % key) def __setattr__(self, key, value): self[key] = value def getValue(self, key): return getattr(self, key, None) def getValueOrDefault(self, key): value = getattr(self, key, None) if value is None: field = self.__mappings__[key] if field.default is not None: value = field.default() if callable(field.default) else field.default logging.debug('using default value for %s: %s' % (key, str(value))) setattr(self, key, value) return value
Model
從dict
繼承,所以具備所有dict
的功能,同時又實現了特殊方法__getattr__()
和__setattr__()
,因此又可以像引用普通字段那樣寫:
>>> user['id'] 123 >>> user.id 123
以及Field
和各種Field
子類:
class Field(object): def __init__(self, name, column_type, primary_key, default): self.name = name self.column_type = column_type self.primary_key = primary_key self.default = default def __str__(self): return '<%s, %s:%s>' % (self.__class__.__name__, self.column_type, self.name)
映射varchar
的StringField
:
class StringField(Field): def __init__(self, name=None, primary_key=False, default=None, ddl='varchar(100)'): super().__init__(name, ddl, primary_key, default)
注意到Model
只是一個基類,如何將具體的子類如User
的映射信息讀取出來呢?答案就是通過metaclass:ModelMetaclass
:
class ModelMetaclass(type): def __new__(cls, name, bases, attrs): # 排除Model類本身: if name=='Model': return type.__new__(cls, name, bases, attrs) # 獲取table名稱: tableName = attrs.get('__table__', None) or name logging.info('found model: %s (table: %s)' % (name, tableName)) # 獲取所有的Field和主鍵名: mappings = dict() fields = [] primaryKey = None for k, v in attrs.items(): if isinstance(v, Field): logging.info(' found mapping: %s ==> %s' % (k, v)) mappings[k] = v if v.primary_key: # 找到主鍵: if primaryKey: raise RuntimeError('Duplicate primary key for field: %s' % k) primaryKey = k else: fields.append(k) if not primaryKey: raise RuntimeError('Primary key not found.') for k in mappings.keys(): attrs.pop(k) escaped_fields = list(map(lambda f: '`%s`' % f, fields)) attrs['__mappings__'] = mappings # 保存屬性和列的映射關系 attrs['__table__'] = tableName attrs['__primary_key__'] = primaryKey # 主鍵屬性名 attrs['__fields__'] = fields # 除主鍵外的屬性名 # 構造默認的SELECT, INSERT, UPDATE和DELETE語句: attrs['__select__'] = 'select `%s`, %s from `%s`' % (primaryKey, ', '.join(escaped_fields), tableName) attrs['__insert__'] = 'insert into `%s` (%s, `%s`) values (%s)' % (tableName, ', '.join(escaped_fields), primaryKey, create_args_string(len(escaped_fields) + 1)) attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (tableName, ', '.join(map(lambda f: '`%s`=?' % (mappings.get(f).name or f), fields)), primaryKey) attrs['__delete__'] = 'delete from `%s` where `%s`=?' % (tableName, primaryKey) return type.__new__(cls, name, bases, attrs)
這樣,任何繼承自Model的類(比如User),會自動通過ModelMetaclass掃描映射關系,並存儲到自身的類屬性如__table__
、__mappings__
中。
然后,我們往Model類添加class方法,就可以讓所有子類調用class方法:
class Model(dict): ... @classmethod @asyncio.coroutine def find(cls, pk): ' find object by primary key. ' rs = yield from select('%s where `%s`=?' % (cls.__select__, cls.__primary_key__), [pk], 1) if len(rs) == 0: return None return cls(**rs[0])
User類現在就可以通過類方法實現主鍵查找:
user = yield from User.find('123')
往Model類添加實例方法,就可以讓所有子類調用實例方法:
class Model(dict): ... @asyncio.coroutine def save(self): args = list(map(self.getValueOrDefault, self.__fields__)) args.append(self.getValueOrDefault(self.__primary_key__)) rows = yield from execute(self.__insert__, args) if rows != 1: logging.warn('failed to insert record: affected rows: %s' % rows)
這樣,就可以把一個User實例存入數據庫:
user = User(id=123, name='Michael') yield from user.save()
最后一步是完善ORM,對於查找,我們可以實現以下方法:
-
findAll() - 根據WHERE條件查找;
-
findNumber() - 根據WHERE條件查找,但返回的是整數,適用於
select count(*)
類型的SQL。
以及update()
和remove()
方法。
所有這些方法都必須用@asyncio.coroutine
裝飾,變成一個協程。
調用時需要特別注意:
user.save()
沒有任何效果,因為調用save()
僅僅是創建了一個協程,並沒有執行它。一定要用:
yield from user.save()
才真正執行了INSERT操作。
最后看看我們實現的ORM模塊一共多少行代碼?累計不到300多行。用Python寫一個ORM是不是很容易呢?