Developing a Multi-Tenant Isolation Plugin for MyBatis


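Even with SaaS now mainstream, many systems still use a single database to serve multiple tenants. In this setup the tenants usually share one set of tables, and their resources are isolated at the SQL statement level: for example, each table gets a tenant identifier column, and every query appends a filter such as TenantId=xxx. This is a cheap and simple way to offer multi-tenant service, but the isolation condition must be added to every SQL statement that runs; otherwise data from different tenants gets mixed up.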

This isolation should be applied automatically, so in this post I use MyBatis's plugin mechanism to build a multi-tenant SQL isolation plugin.

I. Design Requirements

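1. First, we need a way to identify which tables require multi-tenant isolation and to determine the name of the tenant isolation column.

2. Next, intercept the prepare method in MyBatis's execution flow, rewrite the SQL to add the tenant isolation condition, and replace the original statement with the rewritten one.

3. Find a way to add the condition at every nesting level for each matching table, since CRUD statements often contain subqueries, without losing the original WHERE conditions.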

II. Design Approach

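For requirement 1, we can define a condition-field decision maker that decides whether a given table needs the tenant filter, for example an interface named ITableFieldConditionDecision: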

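/**
 * Table field condition decision maker.
 * Decides whether a filter on a given field should be added for a given table.
 *
 * @author liushuishang@gmail.com
 * @date 2017/12/23 15:49
 **/
public interface ITableFieldConditionDecision {

    /**
     * Whether the condition field may take a null value.
     * @return true if a null value is allowed
     */
    boolean isAllowNullValue();
    /**
     * Decides whether a filter on the given field should be added for the given table.
     *
     * @param tableName   table name
     * @param fieldName   field (column) name
     * @return true if the filter should be added
     */
    boolean adjudge(String tableName, String fieldName);

}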
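One possible implementation of the interface, as a minimal sketch: it assumes a table is isolated when its whole name matches a Java regular expression or appears in an explicit set of names. The class name RegexOrSetDecision is hypothetical; the plugin in section III builds an equivalent decision as an anonymous class.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.regex.Pattern;

// Same contract as the ITableFieldConditionDecision interface above,
// repeated so that this sketch compiles on its own.
interface ITableFieldConditionDecision {
    boolean isAllowNullValue();
    boolean adjudge(String tableName, String fieldName);
}

// Hypothetical implementation: isolate a table if it matches the regex
// or is listed explicitly. Either source may be absent.
class RegexOrSetDecision implements ITableFieldConditionDecision {
    private final Pattern tablePattern;   // null when no regex is configured
    private final Set<String> tableSet;   // empty when no names are configured

    RegexOrSetDecision(String tableRegex, Set<String> tableNames) {
        this.tablePattern = tableRegex == null ? null : Pattern.compile(tableRegex);
        this.tableSet = tableNames == null ? new HashSet<String>() : new HashSet<String>(tableNames);
    }

    @Override
    public boolean isAllowNullValue() {
        return false; // never emit "column = NULL"
    }

    @Override
    public boolean adjudge(String tableName, String fieldName) {
        if (tableName == null) return false;
        // matcher(...).matches() requires the WHOLE table name to match
        if (tablePattern != null && tablePattern.matcher(tableName).matches()) return true;
        return tableSet.contains(tableName);
    }
}
```

Note that in a Java regex `uam_*` means "uam" followed by zero or more underscores, so selecting every `uam_` table needs `uam_.*`.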

Then, where the plugin is registered, supply the parameters needed to initialize the decision maker:

<!-- Multi-tenant isolation plugin -->
<bean class="com.smartdata360.smartfx.dao.plugin.MultiTenantPlugin">
    <property name="properties">
        <value>
            <!-- Dialect of the current database -->
            dialect=postgresql
            <!-- Name of the tenant isolation column -->
            tenantIdField=domain
            <!-- Java regular expression for the names of tables to isolate -->
            tableRegex=uam_.*
            <!-- Names of tables to isolate, comma-separated -->
            tableNames=uam_user,uam_role
        </value>
    </property>
</bean>
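In the Spring configuration above, the text inside `<value>` is converted into a `java.util.Properties` object before the plugin's `setProperties` is called. The stdlib-only sketch below assumes the property names `tableRegex` and `tableNames` read by the plugin code in section III, and shows parsing plus validation: the list is split on commas (Commons Lang's one-argument `StringUtils.split` splits on whitespace, not commas) and at least one table selector is required. `TenantPluginSettings` is a hypothetical name.

```java
import java.io.IOException;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Properties;
import java.util.Set;

// Hypothetical holder for the plugin settings read from the <value> block.
class TenantPluginSettings {
    final String dialect;
    final String tenantIdField;
    final String tableRegex;        // null when not configured
    final Set<String> tableNames;   // empty when not configured

    TenantPluginSettings(Properties p) {
        dialect = require(p, "dialect");
        tenantIdField = require(p, "tenantIdField");
        tableRegex = blankToNull(p.getProperty("tableRegex"));
        Set<String> names = new LinkedHashSet<>();
        String raw = p.getProperty("tableNames");
        if (raw != null) {
            for (String part : raw.split(",")) {    // comma-separated list
                String name = part.trim();
                if (!name.isEmpty()) names.add(name);
            }
        }
        tableNames = Collections.unmodifiableSet(names);
        // At least one way of selecting tables must be configured.
        if (tableRegex == null && tableNames.isEmpty()) {
            throw new IllegalArgumentException("either tableRegex or tableNames is required");
        }
    }

    private static String require(Properties p, String key) {
        String v = blankToNull(p.getProperty(key));
        if (v == null) throw new IllegalArgumentException("missing property: " + key);
        return v;
    }

    private static String blankToNull(String s) {
        return (s == null || s.trim().isEmpty()) ? null : s.trim();
    }

    static TenantPluginSettings parse(String text) {
        Properties p = new Properties();
        try {
            p.load(new StringReader(text));
        } catch (IOException e) {
            throw new UncheckedIOException(e);      // StringReader does not actually fail
        }
        return new TenantPluginSettings(p);
    }
}
```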

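For requirement 2, we develop a MyBatis interceptor, MultiTenantPlugin. It extracts the SQL statement that is about to be prepared, rewrites it, and puts the rewritten statement back, so MyBatis ends up executing our modified SQL.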

/**
 * Multi-tenant data isolation plugin
 *
 * @author liushuishang@gmail.com
 * @date 2017/12/21 11:58
 **/
@Intercepts({
        @Signature(type = StatementHandler.class,
                method = "prepare",
                args = {Connection.class})})
public class MultiTenantPlugin extends BasePlugin { ... }
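MyBatis applies an interceptor by wrapping the target object in a JDK dynamic proxy (its `Plugin` class), so a call to a method listed in `@Signature` is routed to `intercept()` before reaching the real object. The stdlib-only sketch below imitates that idea with a hypothetical `SqlPreparer` interface and `RewritingProxy` helper; unlike the real plugin, which rewrites `BoundSql` through `MetaObject`, it simply rewrites the method argument before proceeding.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.util.function.UnaryOperator;

// Hypothetical stand-in for StatementHandler: "prepares" the SQL it is given.
interface SqlPreparer {
    String prepare(String sql);
}

// Wraps a target so that prepare(...) receives rewritten SQL,
// roughly like a MyBatis plugin sitting in front of StatementHandler.prepare.
final class RewritingProxy {
    private RewritingProxy() {}

    static SqlPreparer wrap(SqlPreparer target, UnaryOperator<String> rewriter) {
        InvocationHandler handler = (proxy, method, args) -> {
            if ("prepare".equals(method.getName()) && args != null && args.length == 1) {
                args[0] = rewriter.apply((String) args[0]);   // rewrite before proceeding
            }
            return method.invoke(target, args);               // plays the role of invocation.proceed()
        };
        return (SqlPreparer) Proxy.newProxyInstance(
                SqlPreparer.class.getClassLoader(),
                new Class<?>[] {SqlPreparer.class},
                handler);
    }
}
```

The string-appending rewriter used to exercise such a proxy is only a stand-in: blindly appending text breaks on statements with ORDER BY, GROUP BY or subqueries, which is why requirement 3 works on an AST instead.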

For requirement 3, I use the SQL parser module of Alibaba's Druid to parse the SQL and append the conditions. The process is roughly as follows:

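(1) Parse the SQL into an AST, in which essentially every part of the statement has a corresponding object.

(2) Traverse the AST, collecting the SELECT statements, query blocks and SQLExpr nodes, and extract each table name and alias; the decision maker then decides whether the tenant condition is needed. If it is, extend the existing condition with the tenant filter; otherwise leave it unchanged.

(3) Convert the modified AST back into a SQL statement.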
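Steps (1) to (3) can be illustrated without Druid. The sketch below uses a hypothetical miniature AST (`Source`, `Table`, `Join`, `SubQuery` and `Select` are illustrative names, not Druid classes) and mirrors the recursion of `addSelectStatementCondition` in section III: a plain table adds `alias.field = 'value'` to the WHERE clause of the query that owns it, a join visits both sides, and a subquery receives the condition in its own WHERE clause. It builds conditions as strings for brevity, whereas the real helper builds SQLExpr nodes.

```java
import java.util.function.BiPredicate;

// Hypothetical miniature AST with the three table-source kinds the helper handles.
abstract class Source {
    abstract String render();
}

final class Table extends Source {
    final String name, alias;                  // alias may be null
    Table(String name, String alias) { this.name = name; this.alias = alias; }
    String render() { return alias == null ? name : name + " " + alias; }
}

final class Join extends Source {
    final Source left, right;
    final String on;
    Join(Source left, Source right, String on) { this.left = left; this.right = right; this.on = on; }
    String render() { return left.render() + " join " + right.render() + " on " + on; }
}

final class SubQuery extends Source {
    final Select select;
    final String alias;
    SubQuery(Select select, String alias) { this.select = select; this.alias = alias; }
    String render() { return "(" + select.render() + ") " + alias; }
}

final class Select {
    final String columns;
    final Source from;
    String where;                              // null when there is no WHERE clause
    Select(String columns, Source from, String where) { this.columns = columns; this.from = from; this.where = where; }
    String render() { return "select " + columns + " from " + from.render() + (where == null ? "" : " where " + where); }
}

final class TenantRewriter {
    private final BiPredicate<String, String> decision;   // (tableName, fieldName) -> isolate?

    TenantRewriter(BiPredicate<String, String> decision) { this.decision = decision; }

    // Walks the FROM tree; each table's condition goes to the query that owns it.
    void addCondition(Select query, Source from, String field, String value) {
        if (from instanceof Table) {
            Table t = (Table) from;
            if (!decision.test(t.name, field)) return;
            String column = t.alias == null ? field : t.alias + "." + field;
            String cond = column + " = '" + value.replace("'", "''") + "'";   // escape quotes
            // keep the original predicate intact: (original) and cond
            query.where = query.where == null ? cond : "(" + query.where + ") and " + cond;
        } else if (from instanceof Join) {
            Join j = (Join) from;
            addCondition(query, j.left, field, value);
            addCondition(query, j.right, field, value);
        } else if (from instanceof SubQuery) {
            Select inner = ((SubQuery) from).select;
            addCondition(inner, inner.from, field, value);    // subquery gets its own WHERE
        } else {
            throw new UnsupportedOperationException("unhandled source: " + from);
        }
    }
}
```

Wrapping the original predicate in parentheses keeps its meaning when it contains OR: `a = 1 or b = 2` becomes `(a = 1 or b = 2) and t.domain = 'yay'`.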


[Figure: execution result]

III. Code Reference

import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.smartdata360.smartfx.dao.extension.MultiTenantContent;
import com.smartdata360.smartfx.dao.sqlparser.ITableFieldConditionDecision;
import com.smartdata360.smartfx.dao.sqlparser.SqlConditionHelper;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.util.*;
import java.util.regex.Pattern;

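/**
 * Multi-tenant data isolation plugin
 *
 * @author liushuishang@gmail.com
 * @date 2017/12/21 11:58
 **/
// Matches StatementHandler.prepare(Connection) in MyBatis versions before 3.4.0;
// from 3.4.0 the method is prepare(Connection, Integer), so use
// args = {Connection.class, Integer.class} there.
@Intercepts({
        @Signature(type = StatementHandler.class,
                method = "prepare",
                args = {Connection.class})})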
public class MultiTenantPlugin extends BasePlugin {

    private final Logger logger = LoggerFactory.getLogger(MultiTenantPlugin.class);

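    /**
     * Dialect of the current database
     */
    private String dialect;
    /**
     * Name of the tenant id column
     */
    private String tenantIdField;

    /**
     * Regular expression for the names of tables that carry the tenant column
     */
    private Pattern tablePattern;

    /**
     * Names of tables that carry the tenant column
     */
    private Set<String> tableSet;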

    private SqlConditionHelper conditionHelper;


    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        String tenantId = MultiTenantContent.getCurrentTenantId();
        // No tenant id bound to the current context: leave the SQL untouched
        if (StringUtils.isBlank(tenantId)) {
            return invocation.proceed();
        }
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        BoundSql boundSql = statementHandler.getBoundSql();
        String newSql = addTenantCondition(boundSql.getSql(), tenantId);
        MetaObject boundSqlMeta = getMetaObject(boundSql);
        // Put the rewritten SQL back into boundSql
        boundSqlMeta.setValue("sql", newSql);

        return invocation.proceed();
    }

    @Override
    public void setProperties(Properties properties) {
        dialect = properties.getProperty("dialect");
        if (StringUtils.isBlank(dialect))
            throw new IllegalArgumentException("MultiTenantPlugin requires the dialect property");
        tenantIdField = properties.getProperty("tenantIdField");
        if (StringUtils.isBlank(tenantIdField))
            throw new IllegalArgumentException("MultiTenantPlugin requires the tenantIdField property");

        String tableRegex = properties.getProperty("tableRegex");
        if (!StringUtils.isBlank(tableRegex))
            tablePattern = Pattern.compile(tableRegex.trim());

        String tableNames = properties.getProperty("tableNames");
        if (!StringUtils.isBlank(tableNames)) {
            // Split on commas (the one-argument StringUtils.split splits on whitespace) and trim each name
            tableSet = new HashSet<String>(Arrays.asList(StringUtils.stripAll(StringUtils.split(tableNames, ','))));
        }
        // At least one of tableRegex / tableNames must be configured
        if (tablePattern == null && tableSet == null)
            throw new IllegalArgumentException("MultiTenantPlugin requires tableRegex or tableNames");

        /**
         * Decision maker for the tenant condition column
         */
        ITableFieldConditionDecision conditionDecision = new ITableFieldConditionDecision() {
            @Override
            public boolean isAllowNullValue() {
                return false;
            }
            @Override
            public boolean adjudge(String tableName, String fieldName) {
                // The whole table name must match the configured pattern
                if (tablePattern != null && tablePattern.matcher(tableName).matches()) return true;
                if (tableSet != null && tableSet.contains(tableName)) return true;
                return false;
            }
        };
        conditionHelper = new SqlConditionHelper(conditionDecision);
    }


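    /**
     * Adds the tenant id filter to the WHERE clause of a SQL statement
     *
     * @param sql      the SQL statement to filter
     * @param tenantId the current tenant id
     * @return the SQL statement with the condition added
     */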
    private String addTenantCondition(String sql, String tenantId) {
        if (StringUtils.isBlank(sql) || StringUtils.isBlank(tenantIdField)) return sql;
        List<SQLStatement> statementList = SQLUtils.parseStatements(sql, dialect);
        if (statementList == null || statementList.size() == 0) return sql;

        SQLStatement sqlStatement = statementList.get(0);
        conditionHelper.addStatementCondition(sqlStatement, tenantIdField, tenantId);
        return SQLUtils.toSQLString(statementList, dialect);
    }

}
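MultiTenantPlugin reads the tenant from `MultiTenantContent.getCurrentTenantId()`, a class the article does not show. A common way to build such a holder is a `ThreadLocal` filled per request; the minimal sketch below assumes that design, and its `setCurrentTenantId`, `clear` and `callAs` methods are hypothetical.

```java
import java.util.function.Supplier;

// Hypothetical tenant-context holder; only getCurrentTenantId() is used by the plugin.
final class MultiTenantContent {
    private static final ThreadLocal<String> CURRENT = new ThreadLocal<>();

    private MultiTenantContent() {}

    public static void setCurrentTenantId(String tenantId) { CURRENT.set(tenantId); }

    public static String getCurrentTenantId() { return CURRENT.get(); }

    // Call when the request ends, or a pooled thread keeps the previous tenant.
    public static void clear() { CURRENT.remove(); }

    // Runs a task with the given tenant bound, then restores the previous value.
    public static <T> T callAs(String tenantId, Supplier<T> task) {
        String previous = CURRENT.get();
        CURRENT.set(tenantId);
        try {
            return task.get();
        } finally {
            if (previous == null) CURRENT.remove(); else CURRENT.set(previous);
        }
    }
}
```

Since `intercept()` leaves the SQL unfiltered when the id is blank, a missing `setCurrentTenantId` call silently disables isolation, and a missing `clear()` lets a pooled thread carry the previous request's tenant.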

import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLName;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.*;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.util.JdbcConstants;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.StringUtils;

import java.util.List;

/**
 * Helper that adds conditions to the WHERE clauses of SQL statements
 *
 * @author liushuishang@gmail.com
 * @date 2017/12/21 15:05
 **/
public class SqlConditionHelper {

    private ITableFieldConditionDecision conditionDecision;

    public SqlConditionHelper(ITableFieldConditionDecision conditionDecision) {
        this.conditionDecision = conditionDecision;
    }

    /**
     * Adds the given WHERE condition to a SQL statement (modified in place)
     *
     * @param sqlStatement
     * @param fieldName
     * @param fieldValue
     */
    public void addStatementCondition(SQLStatement sqlStatement, String fieldName, String fieldValue) {
        if (sqlStatement instanceof SQLSelectStatement) {
            SQLSelectQueryBlock queryObject = (SQLSelectQueryBlock) ((SQLSelectStatement) sqlStatement).getSelect().getQuery();
            addSelectStatementCondition(queryObject, queryObject.getFrom(), fieldName, fieldValue);
        } else if (sqlStatement instanceof SQLUpdateStatement) {
            SQLUpdateStatement updateStatement = (SQLUpdateStatement) sqlStatement;
            addUpdateStatementCondition(updateStatement, fieldName, fieldValue);
        } else if (sqlStatement instanceof SQLDeleteStatement) {
            SQLDeleteStatement deleteStatement = (SQLDeleteStatement) sqlStatement;
            addDeleteStatementCondition(deleteStatement, fieldName, fieldValue);
        } else if (sqlStatement instanceof SQLInsertStatement) {
            SQLInsertStatement insertStatement = (SQLInsertStatement) sqlStatement;
            addInsertStatementCondition(insertStatement, fieldName, fieldValue);
        }
    }

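    /**
     * Adds the condition to the SELECT part of an INSERT ... SELECT statement
     * (INSERT ... VALUES has no WHERE clause and is left unchanged)
     *
     * @param insertStatement
     * @param fieldName
     * @param fieldValue
     */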
    private void addInsertStatementCondition(SQLInsertStatement insertStatement, String fieldName, String fieldValue) {
        if (insertStatement != null) {
            SQLInsertInto sqlInsertInto = insertStatement;
            SQLSelect sqlSelect = sqlInsertInto.getQuery();
            if (sqlSelect != null) {
                SQLSelectQueryBlock selectQueryBlock = (SQLSelectQueryBlock) sqlSelect.getQuery();
                addSelectStatementCondition(selectQueryBlock, selectQueryBlock.getFrom(), fieldName, fieldValue);
            }
        }
    }


    /**
     * Adds the condition to the WHERE clause of a DELETE statement
     *
     * @param deleteStatement
     * @param fieldName
     * @param fieldValue
     */
    private void addDeleteStatementCondition(SQLDeleteStatement deleteStatement, String fieldName, String fieldValue) {
        SQLExpr where = deleteStatement.getWhere();
        // Add the condition inside any subqueries in the WHERE clause
        addSQLExprCondition(where, fieldName, fieldValue);

        SQLExpr newCondition = newEqualityCondition(deleteStatement.getTableName().getSimpleName(),
                deleteStatement.getTableSource().getAlias(), fieldName, fieldValue, where);
        deleteStatement.setWhere(newCondition);

    }

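    /**
     * Adds the condition inside the subqueries found in a WHERE expression,
     * recursing through binary operators
     *
     * @param where      the original WHERE expression
     * @param fieldName
     * @param fieldValue
     */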
    private void addSQLExprCondition(SQLExpr where, String fieldName, String fieldValue) {
        if (where instanceof SQLInSubQueryExpr) {
            SQLInSubQueryExpr inWhere = (SQLInSubQueryExpr) where;
            SQLSelect subSelectObject = inWhere.getSubQuery();
            SQLSelectQueryBlock subQueryObject = (SQLSelectQueryBlock) subSelectObject.getQuery();
            addSelectStatementCondition(subQueryObject, subQueryObject.getFrom(), fieldName, fieldValue);
        } else if (where instanceof SQLBinaryOpExpr) {
            SQLBinaryOpExpr opExpr = (SQLBinaryOpExpr) where;
            SQLExpr left = opExpr.getLeft();
            SQLExpr right = opExpr.getRight();
            addSQLExprCondition(left, fieldName, fieldValue);
            addSQLExprCondition(right, fieldName, fieldValue);
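        } else if (where instanceof SQLQueryExpr) {
            SQLSelectQueryBlock selectQueryBlock = (SQLSelectQueryBlock) (((SQLQueryExpr) where).getSubQuery()).getQuery();
            addSelectStatementCondition(selectQueryBlock, selectQueryBlock.getFrom(), fieldName, fieldValue);
        } else if (where instanceof SQLExistsExpr) {
            // EXISTS (select ...) subqueries also need the tenant filter
            SQLSelectQueryBlock existsQuery = (SQLSelectQueryBlock) ((SQLExistsExpr) where).getSubQuery().getQuery();
            addSelectStatementCondition(existsQuery, existsQuery.getFrom(), fieldName, fieldValue);
        }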
    }

    /**
     * Adds the condition to the WHERE clause of an UPDATE statement
     *
     * @param updateStatement
     * @param fieldName
     * @param fieldValue
     */
    private void addUpdateStatementCondition(SQLUpdateStatement updateStatement, String fieldName, String fieldValue) {
        SQLExpr where = updateStatement.getWhere();
        // Add the condition inside any subqueries in the WHERE clause
        addSQLExprCondition(where, fieldName, fieldValue);
        SQLExpr newCondition = newEqualityCondition(updateStatement.getTableName().getSimpleName(),
                updateStatement.getTableSource().getAlias(), fieldName, fieldValue, where);
        updateStatement.setWhere(newCondition);
    }

    /**
     * Adds the condition to a query block for each matching table in the given table source
     *
     * @param queryObject the query block whose WHERE clause receives the condition
     * @param from        the table source to inspect: a table, a join or a subquery
     * @param fieldName
     * @param fieldValue
     */
    private void addSelectStatementCondition(SQLSelectQueryBlock queryObject, SQLTableSource from, String fieldName, String fieldValue) {
        if (StringUtils.isBlank(fieldName) || from == null || queryObject == null) return;

        SQLExpr originCondition = queryObject.getWhere();
        if (from instanceof SQLExprTableSource) {
            // SQLName covers both plain (user) and schema-qualified (public.user) table names
            SQLExpr tableExpr = ((SQLExprTableSource) from).getExpr();
            String tableName = tableExpr instanceof SQLName ? ((SQLName) tableExpr).getSimpleName() : tableExpr.toString();
            String alias = from.getAlias();
            SQLExpr newCondition = newEqualityCondition(tableName, alias, fieldName, fieldValue, originCondition);
            queryObject.setWhere(newCondition);
        } else if (from instanceof SQLJoinTableSource) {
            // Visit both sides of the join; their conditions go into the same WHERE clause
            SQLJoinTableSource joinObject = (SQLJoinTableSource) from;
            SQLTableSource left = joinObject.getLeft();
            SQLTableSource right = joinObject.getRight();

            addSelectStatementCondition(queryObject, left, fieldName, fieldValue);
            addSelectStatementCondition(queryObject, right, fieldName, fieldValue);

        } else if (from instanceof SQLSubqueryTableSource) {
            // A derived table gets the condition in its own WHERE clause
            SQLSelect subSelectObject = ((SQLSubqueryTableSource) from).getSelect();
            SQLSelectQueryBlock subQueryObject = (SQLSelectQueryBlock) subSelectObject.getQuery();
            addSelectStatementCondition(subQueryObject, subQueryObject.getFrom(), fieldName, fieldValue);
        } else {
            throw new NotImplementedException("Unhandled table source type: " + from.getClass().getName());
        }
    }

    /**
     * Builds a new condition by ANDing "alias.field = 'value'" with the original condition
     *
     * @param tableName       table name
     * @param tableAlias      table alias
     * @param fieldName
     * @param fieldValue
     * @param originCondition
     * @return the combined condition, or originCondition if no filter should be added
     */
    private SQLExpr newEqualityCondition(String tableName, String tableAlias, String fieldName, String fieldValue, SQLExpr originCondition) {
        // The decision maker says this table does not need the condition
        if (!conditionDecision.adjudge(tableName, fieldName)) return originCondition;
        // The condition value is null but null values are not allowed
        if (fieldValue == null && !conditionDecision.isAllowNullValue()) return originCondition;

        String qualifiedName = StringUtils.isBlank(tableAlias) ? fieldName : tableAlias + "." + fieldName;
        SQLExpr condition = new SQLBinaryOpExpr(new SQLIdentifierExpr(qualifiedName), new SQLCharExpr(fieldValue), SQLBinaryOperator.Equality);
        return SQLUtils.buildCondition(SQLBinaryOperator.BooleanAnd, condition, false, originCondition);
    }


    public static void main(String[] args) {
//        String sql = "select * from user s  ";
//        String sql = "select * from user s where s.name='333'";
//        String sql = "select * from (select * from tab t where id = 2 and name = 'wenshao') s where s.name='333'";
//        String sql="select u.*,g.name from user u join user_group g on u.groupId=g.groupId where u.name='123'";

//        String sql = "update user set name=? where id =(select id from user s)";
//        String sql = "delete from user where id = ( select id from user s )";

//        String sql = "insert into user (id,name) select g.id,g.name from user_group g where id=1";

        String sql = "select u.*,g.name from user u join (select * from user_group g  join user_role r on g.role_code=r.code  ) g on u.groupId=g.groupId where u.name='123'";
        List<SQLStatement> statementList = SQLUtils.parseStatements(sql, JdbcConstants.POSTGRESQL);
        SQLStatement sqlStatement = statementList.get(0);
        // Decision maker that adds the condition to every table
        SqlConditionHelper helper = new SqlConditionHelper(new ITableFieldConditionDecision() {
            @Override
            public boolean adjudge(String tableName, String fieldName) {
                return true;
            }

            @Override
            public boolean isAllowNullValue() {
                return false;
            }
        });
        // Add the tenant condition: "domain" is the field name, "yay" is the filter value
        helper.addStatementCondition(sqlStatement, "domain", "yay");
        System.out.println("Original SQL: " + sql);
        System.out.println("Rewritten SQL: " + SQLUtils.toSQLString(statementList, JdbcConstants.POSTGRESQL));
    }


}

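Because of time and environment constraints, this is only a basic version and may not be thoroughly tested; suggestions for corrections are welcome.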

