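Even in the SaaS era, many systems still use a single database to serve multiple tenants. In that setup the tenants usually share one set of tables, and each tenant's resources are isolated at the SQL-statement level. For example, a tenant identifier column is added, and every query appends the filter TenantId=xxx. This makes multi-tenancy cheap and simple to provide. However, the isolation filter must be appended to every SQL statement that is executed, or data from different tenants gets mixed up.
The isolation should therefore happen automatically. In this article I use MyBatis's plugin mechanism to build a plugin that adds multi-tenant isolation to SQL statements.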
I. Design Requirements
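1. First, we need a way to identify which tables require multi-tenant isolation, and to specify the name of the tenant isolation column.
2. Next, intercept the prepare method in MyBatis's execution flow, rewrite the SQL to add the tenant isolation condition, and replace the original SQL with the rewritten one.
3. Find a way to add the condition to every identified table at every nesting level, because CRUD statements often contain subqueries. The original WHERE conditions must not be lost.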
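Requirement 3 explains why plain string concatenation is not enough. The JDK-only sketch below shows the problem. The class and method names and the tenant_id column are hypothetical. If you naively append AND tenant_id = ... to a WHERE clause that contains OR, the query's meaning changes, because AND binds tighter than OR. Wrapping the original condition in parentheses preserves its meaning. The sketch handles only one flat WHERE string and ignores subqueries, which is the gap the AST-based approach described below closes.

```java
public class TenantAppendDemo {

    // Naive approach: concatenate the tenant filter onto the existing WHERE text.
    // tenantId is not escaped here; demo only.
    static String naiveAppend(String where, String tenantId) {
        return where + " AND tenant_id = '" + tenantId + "'";
    }

    // Safer approach: keep the original condition as one unit by parenthesizing it,
    // so operator precedence cannot let rows from other tenants through.
    static String safeAppend(String where, String tenantId) {
        String filter = "tenant_id = '" + tenantId + "'";
        if (where == null || where.trim().isEmpty()) {
            return filter;
        }
        return "(" + where + ") AND " + filter;
    }

    public static void main(String[] args) {
        String where = "status = 'A' OR owner = 'bob'";
        // Evaluated as: status = 'A' OR (owner = 'bob' AND tenant_id = 't1'),
        // so every status = 'A' row of ANY tenant still matches.
        System.out.println(naiveAppend(where, "t1"));
        // Evaluated as: (status = 'A' OR owner = 'bob') AND tenant_id = 't1'
        System.out.println(safeAppend(where, "t1"));
    }
}
```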
II. Design Approach
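For requirement 1, we can define a condition-field decider. It decides whether a given table needs the tenant filter condition. For example, define the interface ITableFieldConditionDecision: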
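/**
 * Table field condition decider.
 * Decides whether a given table needs a filter condition on a given field.
 *
 * @author liushuishang@gmail.com
 * @date 2017/12/23 15:49
 **/
public interface ITableFieldConditionDecision {

    /**
     * Whether the condition field allows a null value.
     *
     * @return true if null may be used as the filter value
     */
    boolean isAllowNullValue();

    /**
     * Decides whether a table needs a filter on the given field.
     *
     * @param tableName table name
     * @param fieldName field name
     * @return true if the filter condition should be added
     */
    boolean adjudge(String tableName, String fieldName);
}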
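One way to implement the decider is to combine a regular expression with an explicit list of table names. The sketch below re-declares the interface so that it compiles on its own. RegexOrSetDecision is a hypothetical name and is not part of the original project. Note that Pattern.matcher(...).matches() must match the whole table name. A prefix rule therefore needs uam_.*, not uam_*; the latter means "uam" followed by zero or more underscores.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.regex.Pattern;

public class RegexOrSetDecisionDemo {

    // Same contract as ITableFieldConditionDecision in the article.
    interface ITableFieldConditionDecision {
        boolean isAllowNullValue();
        boolean adjudge(String tableName, String fieldName);
    }

    // Hypothetical implementation: a table needs the tenant filter if its name
    // matches the regex OR appears in the explicit set. Either source may be null.
    static final class RegexOrSetDecision implements ITableFieldConditionDecision {
        private final Pattern tablePattern;
        private final Set<String> tableSet;

        RegexOrSetDecision(Pattern tablePattern, Set<String> tableSet) {
            this.tablePattern = tablePattern;
            this.tableSet = tableSet == null ? Collections.<String>emptySet() : tableSet;
        }

        @Override
        public boolean isAllowNullValue() {
            return false; // a null tenant id is not turned into a filter
        }

        @Override
        public boolean adjudge(String tableName, String fieldName) {
            // fieldName is ignored: this decider only looks at the table
            if (tableName == null) return false;
            // matches() requires the WHOLE table name to match the pattern
            if (tablePattern != null && tablePattern.matcher(tableName).matches()) return true;
            return tableSet.contains(tableName);
        }
    }

    public static void main(String[] args) {
        ITableFieldConditionDecision decision = new RegexOrSetDecision(
                Pattern.compile("uam_.*"),
                new HashSet<>(Arrays.asList("sys_config")));
        System.out.println(decision.adjudge("uam_user", "domain"));   // true (regex)
        System.out.println(decision.adjudge("sys_config", "domain")); // true (set)
        System.out.println(decision.adjudge("sys_log", "domain"));    // false
    }
}
```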
Then, where the plugin is configured, supply the parameters needed to initialize the decider:
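<!-- Multi-tenant isolation plugin -->
<bean class="com.smartdata360.smartfx.dao.plugin.MultiTenantPlugin">
    <property name="properties">
        <value>
            <!-- dialect of the current database -->
            dialect=postgresql
            <!-- name of the tenant isolation column -->
            tenantIdField=domain
            <!-- Java regular expression for the tables to isolate (must match the whole table name) -->
            tableRegex=uam_.*
            <!-- comma-separated names of the tables to isolate -->
            tableNames=uam_user,uam_role
        </value>
    </property>
</bean>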
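For a property of type java.util.Properties, Spring converts the text of the value element into a Properties object with one key=value pair per line. It then passes that object to the plugin's setProperties method. The JDK-only sketch below reproduces that conversion step. It also shows one way to split a comma-separated tableNames value and trim each entry. PluginConfigDemo, load and parseTableNames are hypothetical names used only for illustration.

```java
import java.io.IOException;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.LinkedHashSet;
import java.util.Properties;
import java.util.Set;

public class PluginConfigDemo {

    // Split a comma-separated list, trimming whitespace and dropping empty entries.
    static Set<String> parseTableNames(String tableNames) {
        Set<String> result = new LinkedHashSet<>();
        if (tableNames == null) return result;
        for (String name : tableNames.split(",")) {
            String trimmed = name.trim();
            if (!trimmed.isEmpty()) result.add(trimmed);
        }
        return result;
    }

    // Parse "key=value" lines; leading whitespace on each line is ignored.
    static Properties load(String text) {
        Properties props = new Properties();
        try {
            props.load(new StringReader(text));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return props;
    }

    public static void main(String[] args) {
        String text = "dialect=postgresql\n"
                + "tenantIdField=domain\n"
                + "tableRegex=uam_.*\n"
                + "tableNames=uam_user, uam_role\n";
        Properties props = load(text);
        System.out.println(props.getProperty("tenantIdField"));               // domain
        System.out.println(parseTableNames(props.getProperty("tableNames"))); // [uam_user, uam_role]
    }
}
```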
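For requirement 2, we develop a MyBatis interceptor, MultiTenantPlugin. It extracts the SQL statement that is about to be prepared, rewrites it, and puts it back. MyBatis then executes the rewritten SQL.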
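/**
 * Multi-tenant data isolation plugin
 *
 * @author liushuishang@gmail.com
 * @date 2017/12/21 11:58
 **/
@Intercepts({
        // MyBatis 3.4.0+ signature is prepare(Connection, Integer);
        // for older versions use args = {Connection.class}
        @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
public class MultiTenantPlugin extends BasePlugin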
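MyBatis implements plugins with JDK dynamic proxies. Plugin.wrap wraps the target object in a proxy, and the proxy's InvocationHandler calls the interceptor's intercept method when a method matching an @Signature is invoked. The JDK-only sketch below mimics that mechanism to show where the SQL rewrite happens. SqlHandler, EchoHandler and TenantRewriteHandler are hypothetical stand-ins, not MyBatis types. The "rewrite" is a placeholder string append; the article does this step with Druid's AST instead.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

public class ProxyInterceptDemo {

    // Hypothetical stand-in for StatementHandler: prepare() receives the SQL to run.
    interface SqlHandler {
        String prepare(String sql);
    }

    // The real target: reports what it was asked to prepare.
    static final class EchoHandler implements SqlHandler {
        @Override
        public String prepare(String sql) {
            return "prepared: " + sql;
        }
    }

    // Plays the role of MyBatis's Plugin + Interceptor: only "prepare" is
    // intercepted, and its SQL argument is rewritten before reaching the target.
    static final class TenantRewriteHandler implements InvocationHandler {
        private final Object target;
        private final String tenantId;

        TenantRewriteHandler(Object target, String tenantId) {
            this.target = target;
            this.tenantId = tenantId;
        }

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            if ("prepare".equals(method.getName()) && args != null && args.length == 1) {
                // Placeholder rewrite (assumes the SQL has no WHERE clause yet)
                args[0] = args[0] + " WHERE domain = '" + tenantId + "'";
            }
            try {
                return method.invoke(target, args);
            } catch (InvocationTargetException e) {
                throw e.getCause(); // rethrow the target's own exception
            }
        }
    }

    static SqlHandler wrap(SqlHandler target, String tenantId) {
        return (SqlHandler) Proxy.newProxyInstance(
                SqlHandler.class.getClassLoader(),
                new Class<?>[] {SqlHandler.class},
                new TenantRewriteHandler(target, tenantId));
    }

    public static void main(String[] args) {
        SqlHandler handler = wrap(new EchoHandler(), "t1");
        // prints: prepared: SELECT * FROM uam_user WHERE domain = 't1'
        System.out.println(handler.prepare("SELECT * FROM uam_user"));
    }
}
```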
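For requirement 3, I use the SQL parser module of Alibaba Druid to parse the SQL and append the conditions. The process is roughly:
(1) Parse the SQL into an AST, in which essentially every part of the statement has a corresponding object.
(2) Walk the AST to collect the select statements, query blocks and SQLExpr nodes. Extract each table name and alias and pass them to the decider to determine whether the tenant condition is needed. If it is, extend the existing condition with the tenant filter; otherwise leave it unchanged.
(3) Convert the modified AST back into a SQL statement.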
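The three steps can be illustrated without Druid. The sketch below replaces Druid's AST with a tiny hand-built one; Table, Join, Subquery and Select are hypothetical classes. Step (1), parsing, is skipped because the tree is constructed directly. addCondition mirrors the recursion in step (2): a plain table extends the WHERE clause of the query that owns it, a join recurses into both sides with the same query, and a subquery in FROM gets the filter inside its own WHERE clause. toSql performs step (3).

```java
import java.util.function.BiPredicate;

public class TenantAstDemo {

    // A tiny hand-built AST standing in for Druid's SQL AST.
    interface TableSource {}

    static final class Table implements TableSource {
        final String name, alias;
        Table(String name, String alias) { this.name = name; this.alias = alias; }
    }

    static final class Join implements TableSource {
        final TableSource left, right;
        final String on;
        Join(TableSource left, TableSource right, String on) { this.left = left; this.right = right; this.on = on; }
    }

    static final class Subquery implements TableSource {
        final Select select;
        final String alias;
        Subquery(Select select, String alias) { this.select = select; this.alias = alias; }
    }

    static final class Select {
        final String columns;
        final TableSource from;
        String where; // mutable so that the walker can extend it
        Select(String columns, TableSource from, String where) { this.columns = columns; this.from = from; this.where = where; }
    }

    // Step (2): walk the FROM tree and extend the owning query's WHERE clause.
    static void addCondition(Select query, TableSource from, String field, String value,
                             BiPredicate<String, String> decision) {
        if (from instanceof Table) {
            Table t = (Table) from;
            if (!decision.test(t.name, field)) return;
            String column = t.alias == null ? field : t.alias + "." + field;
            String cond = column + " = '" + value + "'";
            query.where = query.where == null ? cond : "(" + query.where + ") AND " + cond;
        } else if (from instanceof Join) {
            Join j = (Join) from;
            addCondition(query, j.left, field, value, decision);  // same query, both sides
            addCondition(query, j.right, field, value, decision);
        } else if (from instanceof Subquery) {
            Select sub = ((Subquery) from).select;
            addCondition(sub, sub.from, field, value, decision);  // filter inside the subquery
        }
    }

    // Step (3): turn the (modified) AST back into SQL text.
    static String toSql(Select q) {
        String sql = "SELECT " + q.columns + " FROM " + render(q.from);
        return q.where == null ? sql : sql + " WHERE " + q.where;
    }

    static String render(TableSource ts) {
        if (ts instanceof Table) {
            Table t = (Table) ts;
            return t.alias == null ? t.name : t.name + " " + t.alias;
        }
        if (ts instanceof Join) {
            Join j = (Join) ts;
            return render(j.left) + " JOIN " + render(j.right) + " ON " + j.on;
        }
        Subquery s = (Subquery) ts;
        return "(" + toSql(s.select) + ") " + s.alias;
    }

    public static void main(String[] args) {
        Select inner = new Select("*", new Table("user_group", "g"), null);
        Select outer = new Select("u.*, g.name",
                new Join(new Table("user", "u"), new Subquery(inner, "g"), "u.groupId = g.groupId"),
                "u.name = '123'");
        addCondition(outer, outer.from, "domain", "yay", (table, field) -> true);
        System.out.println(toSql(outer));
    }
}
```

With a decider that accepts every table, the example prints SELECT u.*, g.name FROM user u JOIN (SELECT * FROM user_group g WHERE g.domain = 'yay') g ON u.groupId = g.groupId WHERE (u.name = '123') AND u.domain = 'yay'. The original condition is parenthesized so that the tenant filter cannot change its meaning.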
Execution result:
III. Code Reference
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.smartdata360.smartfx.dao.extension.MultiTenantContent;
import com.smartdata360.smartfx.dao.sqlparser.ITableFieldConditionDecision;
import com.smartdata360.smartfx.dao.sqlparser.SqlConditionHelper;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.util.*;
import java.util.regex.Pattern;

/**
 * Multi-tenant data isolation plugin
 *
 * @author liushuishang@gmail.com
 * @date 2017/12/21 11:58
 **/
@Intercepts({
        // MyBatis 3.4.0+ signature; for older versions use args = {Connection.class}
        @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
public class MultiTenantPlugin extends BasePlugin {

    private final Logger logger = LoggerFactory.getLogger(MultiTenantPlugin.class);

    /** Dialect of the current database */
    private String dialect;

    /** Name of the tenant id column */
    private String tenantIdField;

    /** Regular expression for the names of tables that need the tenant filter */
    private Pattern tablePattern;

    /** Names of tables that need the tenant filter */
    private Set<String> tableSet;

    private SqlConditionHelper conditionHelper;

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        String tenantId = MultiTenantContent.getCurrentTenantId();
        // Do nothing when there is no current tenant id
        if (StringUtils.isBlank(tenantId)) {
            return invocation.proceed();
        }
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        BoundSql boundSql = statementHandler.getBoundSql();
        String newSql = addTenantCondition(boundSql.getSql(), tenantId);
        // getMetaObject() is provided by BasePlugin (not shown in this article)
        MetaObject boundSqlMeta = getMetaObject(boundSql);
        // Write the new SQL back into boundSql
        boundSqlMeta.setValue("sql", newSql);
        return invocation.proceed();
    }

    @Override
    public void setProperties(Properties properties) {
        dialect = properties.getProperty("dialect");
        if (StringUtils.isBlank(dialect))
            throw new IllegalArgumentException("MultiTenantPlugin needs a dialect property value");
        tenantIdField = properties.getProperty("tenantIdField");
        if (StringUtils.isBlank(tenantIdField))
            throw new IllegalArgumentException("MultiTenantPlugin needs a tenantIdField property value");
        String tableRegex = properties.getProperty("tableRegex");
        if (!StringUtils.isBlank(tableRegex))
            tablePattern = Pattern.compile(tableRegex);
        String tableNames = properties.getProperty("tableNames");
        if (!StringUtils.isBlank(tableNames)) {
            // tableNames is comma-separated; trim whitespace around each name
            tableSet = new HashSet<String>(Arrays.asList(StringUtils.stripAll(StringUtils.split(tableNames, ','))));
        }
        // At least one of the two settings is required
        if (tablePattern == null && tableSet == null)
            throw new IllegalArgumentException("MultiTenantPlugin requires at least one of tableRegex and tableNames");

        // Decider for the tenant condition field
        ITableFieldConditionDecision conditionDecision = new ITableFieldConditionDecision() {
            @Override
            public boolean isAllowNullValue() {
                return false;
            }

            @Override
            public boolean adjudge(String tableName, String fieldName) {
                // The whole table name must match the configured pattern
                if (tablePattern != null && tablePattern.matcher(tableName).matches()) return true;
                if (tableSet != null && tableSet.contains(tableName)) return true;
                return false;
            }
        };
        conditionHelper = new SqlConditionHelper(conditionDecision);
    }

    /**
     * Adds the tenant id filter to the WHERE clause of a SQL statement.
     *
     * @param sql      SQL statement to filter
     * @param tenantId current tenant id
     * @return SQL statement with the condition added
     */
    private String addTenantCondition(String sql, String tenantId) {
        if (StringUtils.isBlank(sql) || StringUtils.isBlank(tenantIdField)) return sql;
        List<SQLStatement> statementList = SQLUtils.parseStatements(sql, dialect);
        if (statementList == null || statementList.size() == 0) return sql;
        SQLStatement sqlStatement = statementList.get(0);
        conditionHelper.addStatementCondition(sqlStatement, tenantIdField, tenantId);
        return SQLUtils.toSQLString(statementList, dialect);
    }
}
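The plugin reads the current tenant from MultiTenantContent.getCurrentTenantId(), but the article does not show that class. A common way to provide such a value is a ThreadLocal holder. It is filled in per request, for example by a servlet filter after authentication, and cleared afterwards. The sketch below assumes that design. Apart from getCurrentTenantId, every method name in it is hypothetical.

```java
public final class MultiTenantContent {

    // One tenant id per thread. Request threads are usually pooled, so always clear it.
    private static final ThreadLocal<String> CURRENT_TENANT = new ThreadLocal<>();

    private MultiTenantContent() {}

    // Hypothetical setter, e.g. called by a web filter after authentication.
    public static void setCurrentTenantId(String tenantId) {
        CURRENT_TENANT.set(tenantId);
    }

    // The only method the plugin relies on.
    public static String getCurrentTenantId() {
        return CURRENT_TENANT.get();
    }

    // Hypothetical cleanup, e.g. in the filter's finally block, so that one request's
    // tenant does not leak into the next request served by the same pooled thread.
    public static void clear() {
        CURRENT_TENANT.remove();
    }

    public static void main(String[] args) throws InterruptedException {
        setCurrentTenantId("t1");
        Thread other = new Thread(() ->
                System.out.println("other thread sees: " + getCurrentTenantId())); // null
        other.start();
        other.join();
        System.out.println("main thread sees: " + getCurrentTenantId()); // t1
        clear();
    }
}
```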
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLName;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.*;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.util.JdbcConstants;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.StringUtils;

import java.util.List;

/**
 * Helper for manipulating the WHERE conditions of SQL statements
 *
 * @author liushuishang@gmail.com
 * @date 2017/12/21 15:05
 **/
public class SqlConditionHelper {

    private ITableFieldConditionDecision conditionDecision;

    public SqlConditionHelper(ITableFieldConditionDecision conditionDecision) {
        this.conditionDecision = conditionDecision;
    }

    /**
     * Adds the given WHERE condition to a SQL statement.
     *
     * @param sqlStatement parsed statement (modified in place)
     * @param fieldName    condition column name
     * @param fieldValue   condition value
     */
    public void addStatementCondition(SQLStatement sqlStatement, String fieldName, String fieldValue) {
        if (sqlStatement instanceof SQLSelectStatement) {
            SQLSelectQuery query = ((SQLSelectStatement) sqlStatement).getSelect().getQuery();
            addSelectQueryCondition(query, fieldName, fieldValue);
        } else if (sqlStatement instanceof SQLUpdateStatement) {
            addUpdateStatementCondition((SQLUpdateStatement) sqlStatement, fieldName, fieldValue);
        } else if (sqlStatement instanceof SQLDeleteStatement) {
            addDeleteStatementCondition((SQLDeleteStatement) sqlStatement, fieldName, fieldValue);
        } else if (sqlStatement instanceof SQLInsertStatement) {
            addInsertStatementCondition((SQLInsertStatement) sqlStatement, fieldName, fieldValue);
        }
    }

    /**
     * Adds the condition to a query, which may be a single query block or a UNION.
     * Subqueries in the WHERE clause are handled as well as the FROM clause.
     */
    private void addSelectQueryCondition(SQLSelectQuery query, String fieldName, String fieldValue) {
        if (query instanceof SQLSelectQueryBlock) {
            SQLSelectQueryBlock queryBlock = (SQLSelectQueryBlock) query;
            // Subqueries inside the WHERE clause (IN, EXISTS, scalar subqueries)
            addSQLExprCondition(queryBlock.getWhere(), fieldName, fieldValue);
            addSelectStatementCondition(queryBlock, queryBlock.getFrom(), fieldName, fieldValue);
        } else if (query instanceof SQLUnionQuery) {
            SQLUnionQuery unionQuery = (SQLUnionQuery) query;
            addSelectQueryCondition(unionQuery.getLeft(), fieldName, fieldValue);
            addSelectQueryCondition(unionQuery.getRight(), fieldName, fieldValue);
        }
    }

    /**
     * Adds the condition to the SELECT part of an INSERT ... SELECT statement.
     */
    private void addInsertStatementCondition(SQLInsertStatement insertStatement, String fieldName, String fieldValue) {
        if (insertStatement != null) {
            SQLSelect sqlSelect = insertStatement.getQuery();
            if (sqlSelect != null) {
                addSelectQueryCondition(sqlSelect.getQuery(), fieldName, fieldValue);
            }
        }
    }

    /**
     * Adds the WHERE condition to a DELETE statement.
     */
    private void addDeleteStatementCondition(SQLDeleteStatement deleteStatement, String fieldName, String fieldValue) {
        SQLExpr where = deleteStatement.getWhere();
        // Add the condition to subqueries inside the WHERE clause
        addSQLExprCondition(where, fieldName, fieldValue);
        SQLExpr newCondition = newEqualityCondition(deleteStatement.getTableName().getSimpleName(),
                deleteStatement.getTableSource().getAlias(), fieldName, fieldValue, where);
        deleteStatement.setWhere(newCondition);
    }

    /**
     * Adds the filter condition to every subquery found in a WHERE expression.
     *
     * @param where source WHERE condition
     */
    private void addSQLExprCondition(SQLExpr where, String fieldName, String fieldValue) {
        if (where instanceof SQLInSubQueryExpr) {
            SQLSelect subSelect = ((SQLInSubQueryExpr) where).getSubQuery();
            addSelectQueryCondition(subSelect.getQuery(), fieldName, fieldValue);
        } else if (where instanceof SQLExistsExpr) {
            SQLSelect subSelect = ((SQLExistsExpr) where).getSubQuery();
            addSelectQueryCondition(subSelect.getQuery(), fieldName, fieldValue);
        } else if (where instanceof SQLBinaryOpExpr) {
            SQLBinaryOpExpr opExpr = (SQLBinaryOpExpr) where;
            addSQLExprCondition(opExpr.getLeft(), fieldName, fieldValue);
            addSQLExprCondition(opExpr.getRight(), fieldName, fieldValue);
        } else if (where instanceof SQLQueryExpr) {
            SQLSelect subSelect = ((SQLQueryExpr) where).getSubQuery();
            addSelectQueryCondition(subSelect.getQuery(), fieldName, fieldValue);
        }
    }

    /**
     * Adds the WHERE condition to an UPDATE statement.
     */
    private void addUpdateStatementCondition(SQLUpdateStatement updateStatement, String fieldName, String fieldValue) {
        SQLExpr where = updateStatement.getWhere();
        // Add the condition to subqueries inside the WHERE clause
        addSQLExprCondition(where, fieldName, fieldValue);
        SQLExpr newCondition = newEqualityCondition(updateStatement.getTableName().getSimpleName(),
                updateStatement.getTableSource().getAlias(), fieldName, fieldValue, where);
        updateStatement.setWhere(newCondition);
    }

    /**
     * Adds the condition to one query block for every table found in its FROM clause.
     * Note: for the nullable side of an outer join (e.g. the right table of a LEFT JOIN),
     * a filter placed in WHERE discards the unmatched rows, so the join behaves like an inner join.
     */
    private void addSelectStatementCondition(SQLSelectQueryBlock queryObject, SQLTableSource from, String fieldName, String fieldValue) {
        if (StringUtils.isBlank(fieldName) || from == null || queryObject == null) return;

        SQLExpr originCondition = queryObject.getWhere();
        if (from instanceof SQLExprTableSource) {
            SQLExpr tableExpr = ((SQLExprTableSource) from).getExpr();
            // SQLName covers both plain (user) and schema-qualified (public.user) names
            if (!(tableExpr instanceof SQLName))
                throw new NotImplementedException("Unsupported table expression: " + tableExpr);
            String tableName = ((SQLName) tableExpr).getSimpleName();
            String alias = from.getAlias();
            SQLExpr newCondition = newEqualityCondition(tableName, alias, fieldName, fieldValue, originCondition);
            queryObject.setWhere(newCondition);
        } else if (from instanceof SQLJoinTableSource) {
            SQLJoinTableSource joinObject = (SQLJoinTableSource) from;
            addSelectStatementCondition(queryObject, joinObject.getLeft(), fieldName, fieldValue);
            addSelectStatementCondition(queryObject, joinObject.getRight(), fieldName, fieldValue);
        } else if (from instanceof SQLSubqueryTableSource) {
            SQLSelect subSelectObject = ((SQLSubqueryTableSource) from).getSelect();
            addSelectQueryCondition(subSelectObject.getQuery(), fieldName, fieldValue);
        } else {
            // Fail closed instead of silently running an unfiltered query
            throw new NotImplementedException("Unsupported table source: " + from.getClass().getName());
        }
    }

    /**
     * Creates a new condition from the original one.
     *
     * @param tableName       table name
     * @param tableAlias      table alias
     * @param fieldName       condition column name
     * @param fieldValue      condition value
     * @param originCondition original condition (may be null)
     * @return the combined condition
     */
    private SQLExpr newEqualityCondition(String tableName, String tableAlias, String fieldName,
                                         String fieldValue, SQLExpr originCondition) {
        // The decider says this table needs no filter
        if (!conditionDecision.adjudge(tableName, fieldName)) return originCondition;
        // The condition field does not accept a null value
        if (fieldValue == null && !conditionDecision.isAllowNullValue()) return originCondition;

        String qualifiedName = StringUtils.isBlank(tableAlias) ? fieldName : tableAlias + "." + fieldName;
        SQLExpr condition = new SQLBinaryOpExpr(new SQLIdentifierExpr(qualifiedName), new SQLCharExpr(fieldValue), SQLBinaryOperator.Equality);
        // Append "AND condition" to the original WHERE (or use the condition alone if there was none)
        return SQLUtils.buildCondition(SQLBinaryOperator.BooleanAnd, condition, false, originCondition);
    }

    public static void main(String[] args) {
        // String sql = "select * from user s ";
        // String sql = "select * from user s where s.name='333'";
        // String sql = "select * from (select * from tab t where id = 2 and name = 'wenshao') s where s.name='333'";
        // String sql = "select u.*,g.name from user u join user_group g on u.groupId=g.groupId where u.name='123'";
        // String sql = "update user set name=? where id =(select id from user s)";
        // String sql = "delete from user where id = ( select id from user s )";
        // String sql = "insert into user (id,name) select g.id,g.name from user_group g where id=1";
        String sql = "select u.*,g.name from user u join (select * from user_group g join user_role r on g.role_code=r.code ) g on u.groupId=g.groupId where u.name='123'";
        List<SQLStatement> statementList = SQLUtils.parseStatements(sql, JdbcConstants.POSTGRESQL);
        SQLStatement sqlStatement = statementList.get(0);
        // Decider: add the condition to every table
        SqlConditionHelper helper = new SqlConditionHelper(new ITableFieldConditionDecision() {
            @Override
            public boolean adjudge(String tableName, String fieldName) {
                return true;
            }

            @Override
            public boolean isAllowNullValue() {
                return false;
            }
        });
        // Add the tenant condition: "domain" is the column name, "yay" is the filter value
        helper.addStatementCondition(sqlStatement, "domain", "yay");
        System.out.println("Original SQL: " + sql);
        System.out.println("Modified SQL: " + SQLUtils.toSQLString(statementList, JdbcConstants.POSTGRESQL));
    }
}
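Because of time and environment constraints, this is only a basic version and may not be thoroughly tested. Corrections and suggestions are welcome.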