EFCore扩展Update方法(实现 Update User SET Id = Id + 1)
前言
- EFCore在操作更新的时候往往需要先查询一遍数据,再去更新相应的字段,如果针对批量更新的话会很麻烦,效率也很低。
- 目前github上 EFCore.Extentions 项目,实现批量更新挺方便的,但是针对 Update User SET Id = Id + 1 这种操作还是没有解决
- 本文主要就是扩展自更新Update
实现原理
- 先根据IQueryable 获取到SQL语句
// Reflection handles used by ToSql to reach non-public members of the EF Core
// query pipeline. QueryCompilerTypeInfo must stay declared first: the field
// initializers below read it, and static initializers run in textual order.
// NOTE(review): these member names (_queryCompiler, _queryModelGenerator,
// _database, Dependencies) look like EF Core 2.x internals — confirm against
// the EF Core version actually referenced by the project.
private static readonly TypeInfo QueryCompilerTypeInfo = typeof(QueryCompiler).GetTypeInfo();

private static readonly FieldInfo QueryCompilerField = typeof(EntityQueryProvider).GetTypeInfo().DeclaredFields.First(x => x.Name == "_queryCompiler");

private static readonly FieldInfo QueryModelGeneratorField = QueryCompilerTypeInfo.DeclaredFields.First(x => x.Name == "_queryModelGenerator");

private static readonly FieldInfo DataBaseField = QueryCompilerTypeInfo.DeclaredFields.Single(x => x.Name == "_database");

// Despite the "Field" suffix this is a PropertyInfo; the name is kept as-is
// in case other code refers to it.
private static readonly PropertyInfo DatabaseDependenciesField = typeof(Database).GetTypeInfo().DeclaredProperties.Single(x => x.Name == "Dependencies");

/// <summary>
/// Converts <paramref name="query"/> into the SQL text produced for it by the
/// relational query model visitor.
/// </summary>
/// <typeparam name="TEntity">Element type of the query.</typeparam>
/// <param name="query">
/// The query to translate. Its <c>Provider</c> is read through
/// <c>QueryCompilerField</c>, which is declared on <c>EntityQueryProvider</c>.
/// </param>
/// <returns>The string form of the first entry in the visitor's <c>Queries</c>.</returns>
/// <exception cref="System.ArgumentNullException"><paramref name="query"/> is null.</exception>
internal static string ToSql<TEntity>(this IQueryable<TEntity> query) where TEntity : class
{
    if (query == null)
    {
        throw new System.ArgumentNullException(nameof(query));
    }

    // Walk from the query provider down to the compiler and its collaborators.
    var queryCompiler = (QueryCompiler)QueryCompilerField.GetValue(query.Provider);
    var modelGenerator = (QueryModelGenerator)QueryModelGeneratorField.GetValue(queryCompiler);
    var database = (IDatabase)DataBaseField.GetValue(queryCompiler);
    var databaseDependencies = (DatabaseDependencies)DatabaseDependenciesField.GetValue(database);

    // Parse the LINQ expression and let a fresh visitor build the executor;
    // building it is what populates modelVisitor.Queries.
    var queryModel = modelGenerator.ParseQuery(query.Expression);
    var queryCompilationContext = databaseDependencies.QueryCompilationContextFactory.Create(false);
    var modelVisitor = (RelationalQueryModelVisitor)queryCompilationContext.CreateQueryModelVisitor();
    modelVisitor.CreateQueryExecutor<TEntity>(queryModel);

    return modelVisitor.Queries.First().ToString();
}
- 把获取到的查询语句中 FROM 之前的部分砍掉,然后拼接
/// <summary>
/// Splits the SQL generated for <paramref name="query"/> into the part that
/// follows the first line (everything from the first line break onwards) and
/// the first bracketed identifier of the statement, which is used as the
/// table alias.
/// </summary>
/// <typeparam name="T">Entity type of the query.</typeparam>
/// <param name="query">Query whose SQL is obtained through <c>ToSql</c>.</param>
/// <returns>
/// A tuple of (SQL starting at the first line break, with curly brackets
/// doubled; table alias without the surrounding square brackets).
/// </returns>
/// <exception cref="InvalidOperationException">
/// The generated SQL contains no line break or no <c>[alias]</c> token.
/// </exception>
public static (string, string) GetBatchSql<T>(IQueryable<T> query) where T : class, new()
{
    string sqlQuery = query.ToSql();

    // Locate the first "[...]" token instead of assuming it starts at a fixed
    // offset, so a prefix other than "SELECT [" does not corrupt the alias.
    int aliasStart = sqlQuery.IndexOf('[') + 1;
    int aliasEnd = sqlQuery.IndexOf(']');

    // Ordinal search: the line break is a literal separator, not linguistic text.
    int restStart = sqlQuery.IndexOf(Environment.NewLine, StringComparison.Ordinal);

    if (aliasStart == 0 || aliasEnd < aliasStart || restStart < 0)
    {
        throw new InvalidOperationException("Unexpected SQL shape, cannot extract the table alias and FROM clause: " + sqlQuery);
    }

    string tableAlias = sqlQuery.Substring(aliasStart, aliasEnd - aliasStart);

    // Curly brackets have to be escaped:
    // https://github.com/aspnet/EntityFrameworkCore/issues/8820
    string sql = sqlQuery.Substring(restStart)
                         .Replace("{", "{{")
                         .Replace("}", "}}");

    return (sql, tableAlias);
}
- 根据传入的表达式 Expression<Func<T,bool>> 生成 (UPDATE [a] SET) 之后要更新的部分。例如:Expression<Func<T,bool>> expression = a => a.Id == a.Id + 1 生成 a.[Id] =a.[Id] + @param_0,其中参数 param_0 = 1
通过分析Expression的节点[NodeType]来生成相应的操作符,递归拼接sql语句和参数
/// <summary>
/// Recursively renders <paramref name="expression"/> as the assignment list of
/// an UPDATE ... SET clause, appending SQL text to <paramref name="sb"/> and
/// one parameter per constant value to <paramref name="sp"/>.
/// </summary>
/// <remarks>
/// The expression is not evaluated as a predicate: <c>==</c> is emitted as the
/// assignment <c>=</c>, and <c>&amp;</c>, <c>&amp;&amp;</c>, <c>|</c>, <c>||</c> are all
/// emitted as the comma that separates assignments.
/// </remarks>
/// <param name="Param">Table alias written in front of every column reference.</param>
/// <param name="expression">Expression node to render.</param>
/// <param name="sb">Receives the generated SQL fragment.</param>
/// <param name="sp">Receives the parameters; names are <c>param_0</c>, <c>param_1</c>, ... by list position.</param>
/// <exception cref="System.NotSupportedException">
/// The expression contains a node or operator that cannot be rendered.
/// Previously such nodes were skipped silently, yielding malformed SQL.
/// </exception>
public static void CreateUpdateBody(string Param, Expression expression, ref StringBuilder sb, ref List<SqlParameter> sp)
{
    switch (expression)
    {
        case BinaryExpression binaryExpression:
            CreateUpdateBody(Param, binaryExpression.Left, ref sb, ref sp);
            sb.Append(GetUpdateOperator(binaryExpression.NodeType));
            CreateUpdateBody(Param, binaryExpression.Right, ref sb, ref sp);
            break;

        // The compiler wraps operands in Convert for nullable or widening
        // arithmetic (e.g. int? + 1); the conversion has no SQL counterpart here.
        case UnaryExpression unaryExpression
            when unaryExpression.NodeType == ExpressionType.Convert
              || unaryExpression.NodeType == ExpressionType.ConvertChecked:
            CreateUpdateBody(Param, unaryExpression.Operand, ref sb, ref sp);
            break;

        case ConstantExpression constantExpression:
            AppendUpdateParameter(constantExpression.Value, sb, sp);
            break;

        // Member access on the lambda parameter is a column reference.
        case MemberExpression columnExpression when IsRootedInParameter(columnExpression):
            sb.Append($"{Param}.[{columnExpression.Member.Name}]");
            break;

        // Any other member access (captured local, static member) is a value:
        // evaluate it now and send it as a parameter instead of a column name.
        case MemberExpression valueExpression:
            AppendUpdateParameter(Expression.Lambda(valueExpression).Compile().DynamicInvoke(), sb, sp);
            break;

        default:
            throw new System.NotSupportedException($"Expression of type '{expression?.NodeType.ToString() ?? "null"}' is not supported in an update body.");
    }
}

// Maps a binary node type to the SQL token emitted between its operands.
private static string GetUpdateOperator(ExpressionType nodeType)
{
    switch (nodeType)
    {
        case ExpressionType.Add:
            return " +";
        case ExpressionType.Divide:
            return " /";
        case ExpressionType.Multiply:
            return " *";
        case ExpressionType.Subtract:
            return " -";
        case ExpressionType.And:
        case ExpressionType.AndAlso:
        case ExpressionType.Or:
        case ExpressionType.OrElse:
            return " ,";
        case ExpressionType.Equal:
            return " =";
        default:
            throw new System.NotSupportedException($"Operator '{nodeType}' is not supported in an update body.");
    }
}

// Registers value as the next positional parameter and appends its placeholder.
private static void AppendUpdateParameter(object value, StringBuilder sb, List<SqlParameter> sp)
{
    var parmName = $"param_{sp.Count}";

    // A CLR null must be sent as DBNull, otherwise the parameter counts as not supplied.
    sp.Add(new SqlParameter(parmName, value ?? System.DBNull.Value));
    sb.Append($" @{parmName}");
}

// True when the member-access chain ends at a lambda parameter (x.Name, x.A.B).
private static bool IsRootedInParameter(MemberExpression memberExpression)
{
    Expression current = memberExpression;
    while (current is MemberExpression member)
    {
        current = member.Expression;
    }

    return current is ParameterExpression;
}
- 最后执行生成的SQL语句,详情请看github上的源码
- 调用
/// <summary>
/// Demo entry point: prints the user with Id 2, applies a self-referencing
/// update through <c>RestValue</c>, then prints the same user again.
/// </summary>
/// <param name="args">Command-line arguments (unused).</param>
static void Main(string[] args)
{
    using (var context = new TestContext())
    {
        // Result is not used below; the call is kept so the query still runs.
        _ = context.User.Select<UserModel>().ToList();

        var user1 = context.User.AsNoTracking().FirstOrDefault(x => x.Id == 2);
        Console.WriteLine("-----------Before Update --------------------");
        // FirstOrDefault yields null when no row matches; ?. avoids a NullReferenceException.
        Console.WriteLine($" {user1?.Id} : {user1?.Name} : {user1?.RoleId} ");

        // In RestValue, == means "assign" and && separates assignments.
        context.User.Where(x => x.Id == 2).RestValue(x => x.Name == (x.Name + " Add Bob") && x.RoleId == (x.RoleId + 1));

        // AsNoTracking so the values are re-read from the database.
        var user2 = context.User.AsNoTracking().FirstOrDefault(x => x.Id == 2);
        Console.WriteLine("-----------After Update --------------------");
        Console.WriteLine($" {user2?.Id} : {user2?.Name} : {user2?.RoleId} ");
    }

    Console.WriteLine("------------结束--------------------");
    Console.ReadLine();
}
- 生成的SQL如下
UPDATE [x] SET x.[Name] =x.[Name] + @param_0 ,x.[RoleId] =x.[RoleId] + @param_1
FROM [User] AS [x]
WHERE [x].[Id] = 2
