修改sqlarchemy源碼使其支持jdbc連接mysql


注意:本文不會將所有完整源碼貼出,只是將具體的思路以及部分源碼貼出,需要感興趣的讀者自己實驗然后實現吆。 

緣起

  公司最近的項目需要將之前的部分業務的數據庫連接方式改為jdbc,但由於之前的項目都使用sqlarchemy作為orm框架,該框架似乎沒有支持jdbc,為了能做最小的修改並滿足需求,所以需要修改sqlarchemy的源碼。

基本配置介紹

  sqlalchemy 版本:1.1.15

  使用jaydebeapi模塊調用jdbc連接mysql

前提:

  1 學會使用jaydebeapi模塊,使用方法具體可以參考:

    https://pypi.python.org/pypi/JayDeBeApi

    介紹的比較詳細的可以參考:http://shuaizki.github.io/language_related/2013/06/22/introduction-to-jpype.html

     jaydebeapi是一個基於jpype的在Cpython中可以通過jdbc連接數據庫的模塊。該模塊的python代碼很少,基本上可以分為連接部分、游標部分、結果轉換部分這三個。一般來說我們可能需要修改的就是結果轉換部分,比如說sqlalchemy查詢時如果某條記錄中含TIME字段,那么該字段一般要表現為timedelta對象。而在jaydebeapi中則返回的是字符串對象,這樣在sqlalchemy中會報錯的。

sqlarchemy為我們實現了ORM對象與語句的轉換,連接池,session(包括對線程的支持scope_session)等較為上層的邏輯,但這些東西在這里我們不需要考慮(當然創建一個連接,生成curcor還是要考慮的),我們要考慮的僅僅是當sqlarchemy把sql語句以及參數傳過來的時候我們該怎么做,以及當sql語句執行后如何對結果進行轉換

 

所需注意的問題

1 sql語句以及參數傳過來的時候我們該怎么做:

  1.1 對參數進行轉義,防止sql注入

2 執行完sql語句后對結果如何處理:

  2.1 我們知道python的基礎sql模塊會對結果進行處理,比如說把NUll轉換為None,把數據庫中的date字段轉換為python的date對象等等

  2.2 一些不知道該怎么形容的數據:

    當我們查詢時,獲取的數據對應字段的元信息

    當我們update或者delete等操作時需要獲取影響了多少行

    當我們插入數據后,如果主鍵是自增字段,我們一般(可以說在sqlarchemy中這是必須)需要獲取該記錄的主鍵值   

     實際上就是支持 python DB API 

3 sqlalchemy增加代碼,使其支持我們修改后的jaydebeapi

 

如何解決

1.1解決方案:

  人家pymysql咋搞,我就咋搞!

  在pymysql.corsors文件中Cursor類中有一個叫做mogrify的方法,這個方法不僅對參數轉義,而且會將參數放置到sql語句中組成完整的可執行sql語句。所以偷一些代碼然后稍加修改就是這樣:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
from functools import partial
from pymysql.converters import escape_item, escape_string
import sys


PY2 = sys.version_info[0] == 2

if PY2:
    import __builtin__
    range_type = xrange
    text_type = unicode
    long_type = long
    str_type = basestring
    unichr = __builtin__.unichr
else:
    range_type = range
    text_type = str
    long_type = int
    str_type = str
    unichr = chr


def _ensure_bytes(x, encoding="utf8"):
    if isinstance(x, text_type):
        x = x.encode(encoding)
    return x


def _escape_args(args, encoding):
    ensure_bytes = partial(_ensure_bytes, encoding=encoding)

    if isinstance(args, (tuple, list)):
        if PY2:
            args = tuple(map(ensure_bytes, args))
        return tuple(escape(arg, encoding) for arg in args)
    elif isinstance(args, dict):
        if PY2:
            args = dict((ensure_bytes(key), ensure_bytes(val)) for
                        (key, val) in args.items())
        return dict((key, escape(val, encoding)) for (key, val) in args.items())


def escape(obj, charset, mapping=None):
    if isinstance(obj, str_type):
        return "'" + escape_string(obj) + "'"
    return escape_item(obj, charset, mapping=mapping)


def mogrify(query, encoding, args=None):
    if PY2:  # Use bytes on Python 2 always
        query = _ensure_bytes(query, encoding=encoding)
    if args is not None:
        # r = _escape_args(args, encoding)
        query = query % _escape_args(args, encoding)
    return query


# 調用一下mogrigy函數
# print(mogrify("select * from ll where a in %s and b = %s", "utf8", [[2, 1], 3]))
View Code

2.1解決方案:

  人家pymysql咋搞,我就咋搞!

  在pymysql.converters中有一個名為decoders的字典,這里面存放了mysql字段與python對象的轉換關系!大概是這樣

def _convert_second_fraction(s):
    if not s:
        return 0
    # Pad zeros to ensure the fraction length in microseconds
    s = s.ljust(6, '0')
    return int(s[:6])

DATETIME_RE = re.compile(r"(\d{1,4})-(\d{1,2})-(\d{1,2})[T ](\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?")


def convert_datetime(obj):
    """Returns a DATETIME or TIMESTAMP column value as a datetime object:

      >>> datetime_or_None('2007-02-25 23:06:20')
      datetime.datetime(2007, 2, 25, 23, 6, 20)
      >>> datetime_or_None('2007-02-25T23:06:20')
      datetime.datetime(2007, 2, 25, 23, 6, 20)

    Illegal values are returned as None:

      >>> datetime_or_None('2007-02-31T23:06:20') is None
      True
      >>> datetime_or_None('0000-00-00 00:00:00') is None
      True

    """
    if not PY2 and isinstance(obj, (bytes, bytearray)):
        obj = obj.decode('ascii')

    m = DATETIME_RE.match(obj)
    if not m:
        return convert_date(obj)

    try:
        groups = list(m.groups())
        groups[-1] = _convert_second_fraction(groups[-1])
        return datetime.datetime(*[ int(x) for x in groups ])
    except ValueError:
        return convert_date(obj)

TIMEDELTA_RE = re.compile(r"(-)?(\d{1,3}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?")


def convert_timedelta(obj):
    """Returns a TIME column as a timedelta object:

      >>> timedelta_or_None('25:06:17')
      datetime.timedelta(1, 3977)
      >>> timedelta_or_None('-25:06:17')
      datetime.timedelta(-2, 83177)

    Illegal values are returned as None:

      >>> timedelta_or_None('random crap') is None
      True

    Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but
    can accept values as (+|-)DD HH:MM:SS. The latter format will not
    be parsed correctly by this function.
    """
    if not PY2 and isinstance(obj, (bytes, bytearray)):
        obj = obj.decode('ascii')

    m = TIMEDELTA_RE.match(obj)
    if not m:
        return None

    try:
        groups = list(m.groups())
        groups[-1] = _convert_second_fraction(groups[-1])
        negate = -1 if groups[0] else 1
        hours, minutes, seconds, microseconds = groups[1:]

        tdelta = datetime.timedelta(
            hours = int(hours),
            minutes = int(minutes),
            seconds = int(seconds),
            microseconds = int(microseconds)
            ) * negate
        return tdelta
    except ValueError:
        return None

TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?")


def convert_time(obj):
    """Returns a TIME column as a time object:

      >>> time_or_None('15:06:17')
      datetime.time(15, 6, 17)

    Illegal values are returned as None:

      >>> time_or_None('-25:06:17') is None
      True
      >>> time_or_None('random crap') is None
      True

    Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but
    can accept values as (+|-)DD HH:MM:SS. The latter format will not
    be parsed correctly by this function.

    Also note that MySQL's TIME column corresponds more closely to
    Python's timedelta and not time. However if you want TIME columns
    to be treated as time-of-day and not a time offset, then you can
    use set this function as the converter for FIELD_TYPE.TIME.
    """
    if not PY2 and isinstance(obj, (bytes, bytearray)):
        obj = obj.decode('ascii')

    m = TIME_RE.match(obj)
    if not m:
        return None

    try:
        groups = list(m.groups())
        groups[-1] = _convert_second_fraction(groups[-1])
        hours, minutes, seconds, microseconds = groups
        return datetime.time(hour=int(hours), minute=int(minutes),
                             second=int(seconds), microsecond=int(microseconds))
    except ValueError:
        return None


def convert_date(obj):
    """Returns a DATE column as a date object:

      >>> date_or_None('2007-02-26')
      datetime.date(2007, 2, 26)

    Illegal values are returned as None:

      >>> date_or_None('2007-02-31') is None
      True
      >>> date_or_None('0000-00-00') is None
      True

    """
    if not PY2 and isinstance(obj, (bytes, bytearray)):
        obj = obj.decode('ascii')
    try:
        return datetime.date(*[ int(x) for x in obj.split('-', 2) ])
    except ValueError:
        return None


def convert_mysql_timestamp(timestamp):
    """Convert a MySQL TIMESTAMP to a Timestamp object.

    MySQL >= 4.1 returns TIMESTAMP in the same format as DATETIME:

      >>> mysql_timestamp_converter('2007-02-25 22:32:17')
      datetime.datetime(2007, 2, 25, 22, 32, 17)

    MySQL < 4.1 uses a big string of numbers:

      >>> mysql_timestamp_converter('20070225223217')
      datetime.datetime(2007, 2, 25, 22, 32, 17)

    Illegal values are returned as None:

      >>> mysql_timestamp_converter('2007-02-31 22:32:17') is None
      True
      >>> mysql_timestamp_converter('00000000000000') is None
      True

    """
    if not PY2 and isinstance(timestamp, (bytes, bytearray)):
        timestamp = timestamp.decode('ascii')
    if timestamp[4] == '-':
        return convert_datetime(timestamp)
    timestamp += "0"*(14-len(timestamp)) # padding
    year, month, day, hour, minute, second = \
        int(timestamp[:4]), int(timestamp[4:6]), int(timestamp[6:8]), \
        int(timestamp[8:10]), int(timestamp[10:12]), int(timestamp[12:14])
    try:
        return datetime.datetime(year, month, day, hour, minute, second)
    except ValueError:
        return None

def convert_set(s):
    if isinstance(s, (bytes, bytearray)):
        return set(s.split(b","))
    return set(s.split(","))


def through(x):
    return x


#def convert_bit(b):
#    b = "\x00" * (8 - len(b)) + b # pad w/ zeroes
#    return struct.unpack(">Q", b)[0]
#
#     the snippet above is right, but MySQLdb doesn't process bits,
#     so we shouldn't either
convert_bit = through


def convert_characters(connection, field, data):
    field_charset = charset_by_id(field.charsetnr).name
    encoding = charset_to_encoding(field_charset)
    if field.flags & FLAG.SET:
        return convert_set(data.decode(encoding))
    if field.flags & FLAG.BINARY:
        return data

    if connection.use_unicode:
        data = data.decode(encoding)
    elif connection.charset != field_charset:
        data = data.decode(encoding)
        data = data.encode(connection.encoding)
    return data

encoders = {
    bool: escape_bool,
    int: escape_int,
    long_type: escape_int,
    float: escape_float,
    str: escape_str,
    text_type: escape_unicode,
    tuple: escape_sequence,
    list: escape_sequence,
    set: escape_sequence,
    frozenset: escape_sequence,
    dict: escape_dict,
    bytearray: escape_bytes,
    type(None): escape_None,
    datetime.date: escape_date,
    datetime.datetime: escape_datetime,
    datetime.timedelta: escape_timedelta,
    datetime.time: escape_time,
    time.struct_time: escape_struct_time,
    Decimal: escape_object,
}

if not PY2 or JYTHON or IRONPYTHON:
    encoders[bytes] = escape_bytes

decoders = {
    FIELD_TYPE.BIT: convert_bit,
    FIELD_TYPE.TINY: int,
    FIELD_TYPE.SHORT: int,
    FIELD_TYPE.LONG: int,
    FIELD_TYPE.FLOAT: float,
    FIELD_TYPE.DOUBLE: float,
    FIELD_TYPE.LONGLONG: int,
    FIELD_TYPE.INT24: int,
    FIELD_TYPE.YEAR: int,
    FIELD_TYPE.TIMESTAMP: convert_mysql_timestamp,
    FIELD_TYPE.DATETIME: convert_datetime,
    FIELD_TYPE.TIME: convert_timedelta,
    FIELD_TYPE.DATE: convert_date,
    FIELD_TYPE.SET: convert_set,
    FIELD_TYPE.BLOB: through,
    FIELD_TYPE.TINY_BLOB: through,
    FIELD_TYPE.MEDIUM_BLOB: through,
    FIELD_TYPE.LONG_BLOB: through,
    FIELD_TYPE.STRING: through,
    FIELD_TYPE.VAR_STRING: through,
    FIELD_TYPE.VARCHAR: through,
    FIELD_TYPE.DECIMAL: Decimal,
    FIELD_TYPE.NEWDECIMAL: Decimal,
}
原始代碼

  而在jaydebeapi中也有一些相似的代碼:

def _to_datetime(rs, col):
    java_val = rs.getTimestamp(col)
    if not java_val:
        return
    d = datetime.datetime.strptime(str(java_val)[:19], "%Y-%m-%d %H:%M:%S")
    d = d.replace(microsecond=int(str(java_val.getNanos())[:6]))
    return str(d)

def _to_time(rs, col):
    java_val = rs.getTime(col)
    if not java_val:
        return
    return str(java_val)

def _to_date(rs, col):
    java_val = rs.getDate(col)
    if not java_val:
        return
    # The following code requires Python 3.3+ on dates before year 1900.
    # d = datetime.datetime.strptime(str(java_val)[:10], "%Y-%m-%d")
    # return d.strftime("%Y-%m-%d")
    # Workaround / simpler soltution (see
    # https://github.com/baztian/jaydebeapi/issues/18):
    return str(java_val)[:10]

def _to_binary(rs, col):
    java_val = rs.getObject(col)
    if java_val is None:
        return
    return str(java_val)

def _java_to_py(java_method):
    def to_py(rs, col):
        java_val = rs.getObject(col)
        if java_val is None:
            return
        if PY2 and isinstance(java_val, (string_type, int, long, float, bool)):
            return java_val
        elif isinstance(java_val, (string_type, int, float, bool)):
            return java_val
        return getattr(java_val, java_method)()
    return to_py

_to_double = _java_to_py('doubleValue')

_to_int = _java_to_py('intValue')

_to_boolean = _java_to_py('booleanValue')


_DEFAULT_CONVERTERS = {
    # see
    # http://download.oracle.com/javase/8/docs/api/java/sql/Types.html
    # for possible keys
    'TIMESTAMP': _to_datetime,
    'TIME': _to_time,
    'DATE': _to_date,
    'BINARY': _to_binary,
    'DECIMAL': _to_double,
    'NUMERIC': _to_double,
    'DOUBLE': _to_double,
    'FLOAT': _to_double,
    'TINYINT': _to_int,
    'INTEGER': _to_int,
    'SMALLINT': _to_int,
    'BOOLEAN': _to_boolean,
    'BIT': _to_boolean
}
原始代碼

  然后我們稍微修改一下即可。 

2.2解決方案

  在jaydebeapi中的Cursor類中,有一個屬性叫做description這個屬性,通過他我們就能獲取查詢時表的字段的元信息

  在jaydebeapi中的Cursor類中,是有rowcount這個屬性的,他表示當我們進行插入更新刪除操作時受影響的行數。

  而在pymysql的cursors文件中的Cursor類中的_do_get_result方法中不僅僅有受影響的行數rowcount,還有lastrowid這個屬性,他表示當我們插入數據且對應主鍵是自增字段時,最后一條數據的主鍵值。但是在jaydebeapi中是沒有的,而這個屬性在sqlalchemy中恰恰是需要的,所以我們要為jaydebeapi的Cursor類加上這個屬性。代碼如下:

class Cursor(object):

    lastrowid = None
    rowcount = -1
    _meta = None
    _prep = None
    _rs = None
    _description = None
...此處省略部分不相關代碼...
def execute(self, operation, parameters=None): if self._connection._closed: raise Error() if not parameters: parameters = () self._close_last() self._prep = self._connection.jconn.prepareStatement(operation) self._set_stmt_parms(self._prep, parameters) try: is_rs = self._prep.execute() # print is_rs except: _handle_sql_exception() # print(dir(self._prep)) # 如果是查詢的話 is_rs就是1 if is_rs: self._rs = self._prep.getResultSet() self._meta = self._rs.getMetaData() self.rowcount = -1 self.lastrowid = None # 插入/修改/刪除時 is_rs都為0 else: self.rowcount = self._prep.getUpdateCount() self.lastrowid = int(self._prep.lastInsertID)

注意:上面的代碼中紅色的代碼是我新增的

3解決方案

    sqlarchemy中底層數據庫連接模塊都放在dialects這個包中,這個包里面有多個包分別是mysql oracle等數據庫的基本數據庫連接類,因為公司只使用mysql數據庫,所以僅僅做了mysql的jdbc擴展,就放到了mysql包中。

大體介紹一下我們將要修改的或者用到的類:

  MySQLDialect

    位置:sqlarchemy.dialects.mysql.base 

    描述:它是一個提供了對mysql數據庫的連接、語句的執行等操作的基類,所以我們需要新寫一個jdbcdialect類並繼承它,然后重寫某些方法。

    為什么會用到:這個就不用多說了

  ExecutionContext

    位置:sqlarchemy.engine.interface

    描述:通過這個東西我們可以獲取當前游標的執行環境,比如說本次sql語句的執行影響了多少行,我們剛插入的一行的自增主鍵值是多少。他也負責把我們所寫的python ORM語句轉換為可以被底層數據庫模塊比如pymysql可以執行的東西。

創建dialect類:

我們知道使用sqlalchemy時首先需要創建一個engine,engine的第一個參數是一個URL,就像這樣:mysql+pymysql://user:password@host:port/db?charset=utf8

  這段URL主要配置了三項:

    配置1 首先聲明了我們要連接mysql數據庫

    配置2 然后配置了底層連接數據庫的dialect(這個單詞翻譯過來叫方言,就好比同是漢語(連接mysql),我們可以說山東話(pymysql)也可以說湖南話(mysqldb))模塊是pymysql

    配置3 配置了用戶名,密碼,主機地址,端口,數據庫名等信息

  通過查看代碼我們可以看到:

    上面中的配置1實際上就是說接下來要在 sqlalchemy.dialects.mysql包中獲取提供數據庫操作等方法的class了。

    配置2實際上就是說 配置1想要找的的class我定義在了sqlalcehmy.dialects.mysql.pymysql中

    配置3會作為URL類包裝解析,然后作為參數傳入dialect實例的create_connect_args方法,以獲取數據庫連接參數。

然后創建engine時還可以指定許多額外的參數,比如說連接池的配置等,這里面有幾個我們需要注意的參數:

  假如我們沒有指定module(數據庫連接底層模塊),默認會調用dialect類的類方法dbapi

  假如我們沒有指定creator(與數據庫建立連接的方法,一般是個函數)這個參數的話默認建立連接時會調用dialect實例的connect方法,並把create_connect_args返回的連接參數傳入。

  當我們第一次與數據庫建立連接時,會調用dialect實例的initialize方法,這個方法會做一系列操作,比如說獲取當前數據庫的版本信息:dialect實例的_get_server_version_info方法;獲取當前isolation級別:dialect實例的get_isolation_level方法

然后就很簡單了:在sqlalchemy中找到sqlalchemy.dialects.mysql這個目錄,然后新建一個名叫jaydebeapi的文件,並找到該目錄下的pymysql文件,你會看到:

from .mysqldb import MySQLDialect_mysqldb
from ...util import langhelpers, py3k


class MySQLDialect_pymysql(MySQLDialect_mysqldb):
    driver = 'pymysql'

    description_encoding = None

    # generally, these two values should be both True
    # or both False.   PyMySQL unicode tests pass all the way back
    # to 0.4 either way.  See [ticket:3337]
    supports_unicode_statements = True
    supports_unicode_binds = True

    def __init__(self, server_side_cursors=False, **kwargs):
        super(MySQLDialect_pymysql, self).__init__(**kwargs)
        self.server_side_cursors = server_side_cursors

    @langhelpers.memoized_property
    def supports_server_side_cursors(self):
        try:
            cursors = __import__('pymysql.cursors').cursors
            self._sscursor = cursors.SSCursor
            return True
        except (ImportError, AttributeError):
            return False

    @classmethod
    def dbapi(cls):
        return __import__('pymysql')

    if py3k:
        def _extract_error_code(self, exception):
            if isinstance(exception.args[0], Exception):
                exception = exception.args[0]
            return exception.args[0]

dialect = MySQLDialect_pymysql
sqlalchemy.dialects.mysql.pymysql源碼

就這一個類,我們只需要繼承這個類並重寫某些方法就是了。就像這樣:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import re
from .pymysql import MySQLDialect_mysqldb


class MySQLDialect_jaydebeapi(MySQLDialect_mysqldb):
    driver = 'jaydebeapi'

    @classmethod
    def dbapi(cls):
        return __import__('jaydebeapi')

    def connect(self, *cargs, **cparams):
        # get_jdbc_conn這個方法就自己寫吧,實際上就是用jaydebeapi生成一個連接,但需要注意,連接的autocommit要設置為False
        return get_jdbc_conn(self.dbapi, **cparams)

    def _get_server_version_info(self, connection):
        dbapi_con = connection.connection
        cursor = dbapi_con.cursor()
        cursor.execute("select version()")
        version = str(cursor.fetchone()[0])
        cursor.close()
        version_list = []
        r = re.compile(r'[.\-]')
        for n in r.split(version):
            try:
                version_list.append(int(n))
            except ValueError:
                version_list.append(n)
        return tuple(version_list)

    def _detect_charset(self, connection):
        """Sniff out the character set in use for connection results."""

        try:
            # note: the SQL here would be
            # "SHOW VARIABLES LIKE 'character_set%%'"
            # print dir(connection.connection)
            cset_name = connection.connection.character_set_name
        except AttributeError:
            return 'utf8'
        else:
            return cset_name()

 

個人在修改源碼中獲取的知識點

點1:

  com.mysql.jdbc.exceptions.MySQLNonTransientConnectionException: Can’t call rollback when autocommit=true

  1. 當開啟autocommit=true時,回滾沒有意義,無論成功/失敗都已經已經將事務提交
  2. autocommit=false,我們需要運行conn.commit()執行事務, 如果失敗則需要conn.rollback()對事務進行回滾;

點2:

   嘗試連接mysql時報錯:Unknown system variable 'transaction_isolation'

  這是因為我的MySQLDialect_jaydebeapi類中的_get_server_version_info方法返回寫死為5.7.21版本,而在mysql的Mysqldialect類的get_isolation_level中,會判斷如果版本大於等於5.7.20的話執行SELECT @@transaction_isolation,反之會執行SELECT @@tx_isolation。

  於是看了看自己的mysql版本是5.7.11 ,遂改變版本號。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM