2020/01/28, ASP.NET Core 3.1, VS2019, Microsoft.EntityFrameworkCore.Relational 3.1.1
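Summary: Building a multi-layer back-end website architecture with ASP.NET Core 3.1 WebApi [4 - Unit of Work and Repository Design]
Using the generic Repository and UnitOfWork patterns to encapsulate the basic create, read, update, and delete (CRUD) methods of the data access layer
About the unit-of-work pattern in this chapter:
The generic repository encapsulates the common CRUD methods, and the unit of work manages all repositories centrally so that they share one consistent database context.
Repositories are always obtained from the unit of work. After data has been changed through a repository, the unit of work commits the changes.
The code is based on the design of Arch/UnitOfWork and mostly follows it. I added comments and removed its support for distributed (multi-database) scenarios.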
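To make that division of labour concrete, here is a minimal in-memory sketch of the pattern. It is not EF Core and not the Arch/UnitOfWork API: `InMemoryUnitOfWork`, `InMemoryRepository<T>`, `GetRepository<T>()` and `SaveChanges()` are hypothetical names chosen for illustration. Repositories can only be obtained from the unit of work, they only stage inserts, and nothing reaches the store until the unit of work commits.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical repository: it stages new entities, and only the owning
// unit of work decides when they are committed.
public class InMemoryRepository<T> where T : class
{
    private readonly List<T> _store = new List<T>();   // committed data
    private readonly List<T> _pending = new List<T>(); // staged, not yet committed

    // Internal constructor: callers obtain repositories from the unit of work.
    internal InMemoryRepository() { }

    public void Insert(T entity) =>
        _pending.Add(entity ?? throw new ArgumentNullException(nameof(entity)));

    public IReadOnlyList<T> GetAll() => _store;

    // Moves staged entities into the store and returns how many were written.
    internal int Commit()
    {
        int written = _pending.Count;
        _store.AddRange(_pending);
        _pending.Clear();
        return written;
    }
}

// Hypothetical unit of work: hands out one repository per entity type and
// commits all of them in one call (the role DbContext.SaveChanges() plays in EF Core).
public class InMemoryUnitOfWork
{
    private readonly Dictionary<Type, object> _repositories = new Dictionary<Type, object>();
    private readonly List<Func<int>> _commits = new List<Func<int>>();

    public InMemoryRepository<T> GetRepository<T>() where T : class
    {
        if (_repositories.TryGetValue(typeof(T), out object existing))
        {
            return (InMemoryRepository<T>)existing;
        }
        var repository = new InMemoryRepository<T>();
        _repositories[typeof(T)] = repository;
        _commits.Add(repository.Commit);
        return repository;
    }

    // Commits the staged changes of every repository; returns the number of entities written.
    public int SaveChanges()
    {
        int total = 0;
        foreach (Func<int> commit in _commits)
        {
            total += commit();
        }
        return total;
    }
}
```

The single `SaveChanges()` is the point of the pattern. With EF Core the same effect comes from giving every repository the same `DbContext` (which the `Repository<TEntity>(DbContext dbContext)` constructor in this chapter accepts), so one `SaveChanges()` on that context persists the changes of all repositories together.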
Adding the package reference
Add a reference to the Microsoft.EntityFrameworkCore.Relational package to the MS.UnitOfWork project:
<ItemGroup>
<PackageReference Include="Microsoft.EntityFrameworkCore.Relational" Version="3.1.1" />
</ItemGroup>
Paging helpers
In the MS.UnitOfWork project, add a Collections folder and add the classes IPagedList.cs, PagedList.cs, IEnumerablePagedListExtensions.cs, and IQueryablePageListExtensions.cs to it.
IPagedList.cs
using System.Collections.Generic;
namespace MS.UnitOfWork.Collections
{
/// <summary>
/// Paging interface for data of any type
/// </summary>
/// <typeparam name="T">Type of the data being paged</typeparam>
public interface IPagedList<T>
{
/// <summary>
/// Index value that denotes the first page (for example 0 or 1)
/// </summary>
int IndexFrom { get; }
/// <summary>
/// Index of the current page
/// </summary>
int PageIndex { get; }
/// <summary>
/// Page size
/// </summary>
int PageSize { get; }
/// <summary>
/// Total number of items
/// </summary>
int TotalCount { get; }
/// <summary>
/// Total number of pages
/// </summary>
int TotalPages { get; }
/// <summary>
/// Items on the current page
/// </summary>
IList<T> Items { get; }
/// <summary>
/// Whether there is a previous page
/// </summary>
bool HasPreviousPage { get; }
/// <summary>
/// Whether there is a next page
/// </summary>
bool HasNextPage { get; }
}
}
PagedList.cs
using System;
using System.Collections.Generic;
using System.Linq;
namespace MS.UnitOfWork.Collections
{
/// <summary>
/// Default implementation of <see cref="IPagedList{T}"/> that pages the data
/// </summary>
/// <typeparam name="T">Type of the data being paged</typeparam>
public class PagedList<T> : IPagedList<T>
{
/// <summary>
/// Index of the current page
/// </summary>
public int PageIndex { get; set; }
/// <summary>
/// Page size
/// </summary>
public int PageSize { get; set; }
/// <summary>
/// Total number of items
/// </summary>
public int TotalCount { get; set; }
/// <summary>
/// Total number of pages
/// </summary>
public int TotalPages { get; set; }
/// <summary>
/// Index value that denotes the first page
/// </summary>
public int IndexFrom { get; set; }
/// <summary>
/// Items on the current page
/// </summary>
public IList<T> Items { get; set; }
/// <summary>
/// Whether there is a previous page
/// </summary>
public bool HasPreviousPage => PageIndex - IndexFrom > 0;
/// <summary>
/// Whether there is a next page
/// </summary>
public bool HasNextPage => PageIndex - IndexFrom + 1 < TotalPages;
/// <summary>
/// Initializes a new instance and cuts the requested page out of the source
/// </summary>
/// <param name="source">The source.</param>
/// <param name="pageIndex">The index of the page.</param>
/// <param name="pageSize">The size of the page.</param>
/// <param name="indexFrom">The index value of the first page.</param>
internal PagedList(IEnumerable<T> source, int pageIndex, int pageSize, int indexFrom)
{
if (indexFrom > pageIndex)
{
throw new ArgumentException($"indexFrom: {indexFrom} > pageIndex: {pageIndex}, the first page index must be less than or equal to the current page index");
}
// IQueryable sources (e.g. EF Core queries) run Count/Skip/Take in the database
if (source is IQueryable<T> queryable)
{
PageIndex = pageIndex;
PageSize = pageSize;
IndexFrom = indexFrom;
TotalCount = queryable.Count();
TotalPages = (int)Math.Ceiling(TotalCount / (double)PageSize);
Items = queryable.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToList();
}
else
{
PageIndex = pageIndex;
PageSize = pageSize;
IndexFrom = indexFrom;
TotalCount = source.Count();
TotalPages = (int)Math.Ceiling(TotalCount / (double)PageSize);
Items = source.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToList();
}
}
/// <summary>
/// Initializes a new instance of the <see cref="PagedList{T}" /> class.
/// </summary>
internal PagedList() => Items = new T[0];
}
/// <summary>
/// Paged data with conversion of the items to another type
/// </summary>
/// <typeparam name="TSource">Type of the source data</typeparam>
/// <typeparam name="TResult">Type of the output data</typeparam>
internal class PagedList<TSource, TResult> : IPagedList<TResult>
{
/// <summary>
/// Index of the current page
/// </summary>
public int PageIndex { get; set; }
/// <summary>
/// Page size
/// </summary>
public int PageSize { get; set; }
/// <summary>
/// Total number of items
/// </summary>
public int TotalCount { get; set; }
/// <summary>
/// Total number of pages
/// </summary>
public int TotalPages { get; set; }
/// <summary>
/// Index value that denotes the first page
/// </summary>
public int IndexFrom { get; set; }
/// <summary>
/// Items on the current page
/// </summary>
public IList<TResult> Items { get; set; }
/// <summary>
/// Whether there is a previous page
/// </summary>
public bool HasPreviousPage => PageIndex - IndexFrom > 0;
/// <summary>
/// Whether there is a next page
/// </summary>
public bool HasNextPage => PageIndex - IndexFrom + 1 < TotalPages;
/// <summary>
/// Initializes a new instance, cuts the requested page out of the source, and converts it
/// </summary>
/// <param name="source">The source.</param>
/// <param name="converter">The converter applied to the items of the page.</param>
/// <param name="pageIndex">The index of the page.</param>
/// <param name="pageSize">The size of the page.</param>
/// <param name="indexFrom">The index value of the first page.</param>
public PagedList(IEnumerable<TSource> source, Func<IEnumerable<TSource>, IEnumerable<TResult>> converter, int pageIndex, int pageSize, int indexFrom)
{
if (indexFrom > pageIndex)
{
throw new ArgumentException($"indexFrom: {indexFrom} > pageIndex: {pageIndex}, the first page index must be less than or equal to the current page index");
}
if (source is IQueryable<TSource> queryable)
{
PageIndex = pageIndex;
PageSize = pageSize;
IndexFrom = indexFrom;
TotalCount = queryable.Count();
TotalPages = (int)Math.Ceiling(TotalCount / (double)PageSize);
var items = queryable.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToArray();
Items = new List<TResult>(converter(items));
}
else
{
PageIndex = pageIndex;
PageSize = pageSize;
IndexFrom = indexFrom;
TotalCount = source.Count();
TotalPages = (int)Math.Ceiling(TotalCount / (double)PageSize);
var items = source.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToArray();
Items = new List<TResult>(converter(items));
}
}
/// <summary>
/// Initializes a new instance of the <see cref="PagedList{TSource, TResult}" /> class.
/// </summary>
/// <param name="source">The source.</param>
/// <param name="converter">The converter.</param>
public PagedList(IPagedList<TSource> source, Func<IEnumerable<TSource>, IEnumerable<TResult>> converter)
{
PageIndex = source.PageIndex;
PageSize = source.PageSize;
IndexFrom = source.IndexFrom;
TotalCount = source.TotalCount;
TotalPages = source.TotalPages;
Items = new List<TResult>(converter(source.Items));
}
}
/// <summary>
/// Provides some help methods for <see cref="IPagedList{T}"/> interface.
/// </summary>
public static class PagedList
{
/// <summary>
/// Creates an empty of <see cref="IPagedList{T}"/>.
/// </summary>
/// <typeparam name="T">The type for paging </typeparam>
/// <returns>An empty instance of <see cref="IPagedList{T}"/>.</returns>
public static IPagedList<T> Empty<T>() => new PagedList<T>();
/// <summary>
/// Creates a new instance of <see cref="IPagedList{TResult}"/> from source of <see cref="IPagedList{TSource}"/> instance.
/// </summary>
/// <typeparam name="TResult">The type of the result.</typeparam>
/// <typeparam name="TSource">The type of the source.</typeparam>
/// <param name="source">The source.</param>
/// <param name="converter">The converter.</param>
/// <returns>An instance of <see cref="IPagedList{TResult}"/>.</returns>
public static IPagedList<TResult> From<TResult, TSource>(IPagedList<TSource> source, Func<IEnumerable<TSource>, IEnumerable<TResult>> converter) => new PagedList<TSource, TResult>(source, converter);
}
}
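The converter overload is easiest to understand in isolation. The sketch below copies the non-IQueryable branch of `PagedList<TSource, TResult>` into a hypothetical static helper, `PageConversionSketch.PageThenConvert`, so it runs without EF Core. It illustrates one property: `Skip`/`Take` run first and the converter only ever sees the items of the current page, so an expensive mapping (for example entity to DTO) is not applied to the whole data set.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-alone helper mirroring the non-IQueryable branch of
// PagedList<TSource, TResult>: cut the page first, then convert only that page.
public static class PageConversionSketch
{
    public static List<TResult> PageThenConvert<TSource, TResult>(
        IEnumerable<TSource> source,
        Func<IEnumerable<TSource>, IEnumerable<TResult>> converter,
        int pageIndex, int pageSize, int indexFrom = 1)
    {
        if (indexFrom > pageIndex)
        {
            throw new ArgumentException("indexFrom must be less than or equal to pageIndex");
        }

        // Same Skip/Take arithmetic as PagedList: page number indexFrom is the first page.
        TSource[] items = source.Skip((pageIndex - indexFrom) * pageSize).Take(pageSize).ToArray();

        // The converter receives only the items of this page.
        return new List<TResult>(converter(items));
    }
}
```

This is why the converter overload of `ToPagedList` and `PagedList.From(...)` suit entity-to-DTO mapping: the mapping runs once per item on the page rather than over the whole source.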
IEnumerablePagedListExtensions.cs
using System;
using System.Collections.Generic;
namespace MS.UnitOfWork.Collections
{
/// <summary>
/// Extension methods that add paging support to <see cref="IEnumerable{T}"/>
/// </summary>
public static class IEnumerablePagedListExtensions
{
/// <summary>
/// Gets one page of data from the source
/// </summary>
/// <typeparam name="T">Type of the data</typeparam>
/// <param name="source">Data source</param>
/// <param name="pageIndex">Index of the requested page</param>
/// <param name="pageSize">Page size</param>
/// <param name="indexFrom">Index value of the first page (default 1)</param>
/// <returns>The requested page</returns>
public static IPagedList<T> ToPagedList<T>(this IEnumerable<T> source, int pageIndex, int pageSize, int indexFrom = 1) => new PagedList<T>(source, pageIndex, pageSize, indexFrom);
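/// <summary>
/// Gets one page of data from the source and converts it to the specified type
/// </summary>
/// <typeparam name="TSource">Type of the source data</typeparam>
/// <typeparam name="TResult">Type of the output data</typeparam>
/// <param name="source">Data source</param>
/// <param name="converter">Converts the items of the page to the output type</param>
/// <param name="pageIndex">Index of the requested page</param>
/// <param name="pageSize">Page size</param>
/// <param name="indexFrom">Index value of the first page (default 1)</param>
/// <returns>The requested page, converted to the output type</returns>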
public static IPagedList<TResult> ToPagedList<TSource, TResult>(this IEnumerable<TSource> source, Func<IEnumerable<TSource>, IEnumerable<TResult>> converter, int pageIndex, int pageSize, int indexFrom = 1) => new PagedList<TSource, TResult>(source, converter, pageIndex, pageSize, indexFrom);
}
}
IQueryablePageListExtensions.cs
using Microsoft.EntityFrameworkCore;
using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace MS.UnitOfWork.Collections
{
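/// <summary>
/// Extension methods that add asynchronous paging support to <see cref="IQueryable{T}"/>
/// </summary>
public static class IQueryablePageListExtensions
{
/// <summary>
/// Gets one page of data from the source (asynchronously)
/// </summary>
/// <typeparam name="T">Type of the data</typeparam>
/// <param name="source">Data source</param>
/// <param name="pageIndex">Index of the requested page</param>
/// <param name="pageSize">Page size</param>
/// <param name="indexFrom">Index value of the first page (default 1)</param>
/// <param name="cancellationToken">Token to observe for cancellation</param>
/// <returns>The requested page</returns>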
public static async Task<IPagedList<T>> ToPagedListAsync<T>(this IQueryable<T> source, int pageIndex, int pageSize, int indexFrom = 1, CancellationToken cancellationToken = default(CancellationToken))
{
if (indexFrom > pageIndex)
{
throw new ArgumentException($"indexFrom: {indexFrom} > pageIndex: {pageIndex}, must indexFrom <= pageIndex");
}
var count = await source.CountAsync(cancellationToken).ConfigureAwait(false);
var items = await source.Skip((pageIndex - indexFrom) * pageSize)
.Take(pageSize).ToListAsync(cancellationToken).ConfigureAwait(false);
var pagedList = new PagedList<T>()
{
PageIndex = pageIndex,
PageSize = pageSize,
IndexFrom = indexFrom,
TotalCount = count,
Items = items,
TotalPages = (int)Math.Ceiling(count / (double)pageSize)
};
return pagedList;
}
}
}
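These paging extension methods cover both IQueryable and IEnumerable data. They are mainly used to page results when fetching data from the database: for an IQueryable source, Count, Skip, and Take become part of the database query instead of being evaluated in memory.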
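The paging arithmetic shared by `PagedList<T>` and `ToPagedListAsync` can be checked by hand. For example, with 45 rows and `pageSize = 20`, `TotalPages = ceil(45 / 20) = 3`, and page 3 (with `indexFrom = 1`) skips `(3 - 1) * 20 = 40` rows, so it holds the last 5 rows. The following sketch restates those formulas in a hypothetical `PageMath.Describe` helper so they can be tried without a database; the `pageSize <= 0` guard is an addition of this sketch, since the original classes do not validate `pageSize`.

```csharp
using System;

// Hypothetical helper reproducing the arithmetic of PagedList<T> and
// ToPagedListAsync, without touching a database.
public static class PageMath
{
    public static (int Skip, int TotalPages, bool HasPreviousPage, bool HasNextPage) Describe(
        int totalCount, int pageIndex, int pageSize, int indexFrom = 1)
    {
        if (indexFrom > pageIndex)
        {
            throw new ArgumentException($"indexFrom: {indexFrom} > pageIndex: {pageIndex}");
        }
        if (pageSize <= 0)
        {
            // Guard added in this sketch only.
            throw new ArgumentOutOfRangeException(nameof(pageSize));
        }

        int skip = (pageIndex - indexFrom) * pageSize;                     // rows before this page
        int totalPages = (int)Math.Ceiling(totalCount / (double)pageSize); // a partial last page counts
        bool hasPrevious = pageIndex - indexFrom > 0;                      // not on the first page
        bool hasNext = pageIndex - indexFrom + 1 < totalPages;             // pages seen so far < page count
        return (skip, totalPages, hasPrevious, hasNext);
    }
}
```

`PageMath.Describe(45, 0, 20, indexFrom: 0)` describes the same first page as `PageMath.Describe(45, 1, 20)`: `indexFrom` only shifts the numbering and does not change which rows are returned.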
Generic repository
In the MS.UnitOfWork project, add a Repository folder and add the classes IRepository.cs and Repository.cs to it.
IRepository.cs
using MS.UnitOfWork.Collections;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.EntityFrameworkCore.Query;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;
namespace MS.UnitOfWork
{
/// <summary>
/// Generic repository interface
/// </summary>
/// <typeparam name="TEntity">Entity type</typeparam>
public interface IRepository<TEntity> where TEntity : class
{
#region GetAll
/// <summary>
/// Gets all entities.
/// Mind performance: materializing this query without a filter loads the whole table.
/// </summary>
/// <returns>The <see cref="IQueryable{TEntity}"/>.</returns>
IQueryable<TEntity> GetAll();
/// <summary>
/// Gets all entities
/// </summary>
/// <param name="predicate">Filter expression</param>
/// <param name="orderBy">Ordering function</param>
/// <param name="include">Navigation properties to include</param>
/// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
/// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
/// <returns>The composed query</returns>
IQueryable<TEntity> GetAll(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false);
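/// <summary>
/// Gets all entities, projected with the given selector (the selector is required, the predicate is optional)
/// </summary>
/// <typeparam name="TResult">Type of the output data</typeparam>
/// <param name="selector">Projection selector</param>
/// <param name="predicate">Filter predicate</param>
/// <param name="orderBy">Ordering function</param>
/// <param name="include">Navigation properties to include</param>
/// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
/// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
/// <returns>The composed, projected query</returns>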
IQueryable<TResult> GetAll<TResult>(
Expression<Func<TEntity, TResult>> selector,
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false
) where TResult : class;
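/// <summary>
/// Gets all entities (asynchronously)
/// </summary>
/// <param name="predicate">Filter expression</param>
/// <param name="orderBy">Ordering function</param>
/// <param name="include">Navigation properties to include</param>
/// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
/// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
/// <returns>The list of matching entities</returns>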
Task<IList<TEntity>> GetAllAsync(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false);
#endregion
#region GetPagedList
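/// <summary>
/// Gets a page of data.
/// Change tracking is disabled by default, so changes to the returned entities are not saved automatically.
/// Global query filters are applied by default.
/// </summary>
/// <param name="predicate">Filter expression</param>
/// <param name="orderBy">Ordering function</param>
/// <param name="include">Navigation properties to include</param>
/// <param name="pageIndex">Current page. Defaults to the first page (1)</param>
/// <param name="pageSize">Page size. Defaults to 20 items</param>
/// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
/// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
/// <returns>The requested page</returns>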
IPagedList<TEntity> GetPagedList(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
int pageIndex = 1,
int pageSize = 20,
bool disableTracking = true,
bool ignoreQueryFilters = false);
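/// <summary>
/// Gets a page of data (asynchronously).
/// Change tracking is disabled by default, so changes to the returned entities are not saved automatically.
/// Global query filters are applied by default.
/// </summary>
/// <param name="predicate">Filter expression</param>
/// <param name="orderBy">Ordering function</param>
/// <param name="include">Navigation properties to include</param>
/// <param name="pageIndex">Current page. Defaults to the first page (1)</param>
/// <param name="pageSize">Page size. Defaults to 20 items</param>
/// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
/// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
/// <param name="cancellationToken">Token to observe for cancellation</param>
/// <returns>The requested page</returns>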
Task<IPagedList<TEntity>> GetPagedListAsync(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
int pageIndex = 1,
int pageSize = 20,
bool disableTracking = true,
bool ignoreQueryFilters = false,
CancellationToken cancellationToken = default);
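/// <summary>
/// Gets a page of projected data.
/// Change tracking is disabled by default, so changes to the returned entities are not saved automatically.
/// Global query filters are applied by default.
/// </summary>
/// <typeparam name="TResult">Type of the output data</typeparam>
/// <param name="selector">Projection selector</param>
/// <param name="predicate">Filter expression</param>
/// <param name="orderBy">Ordering function</param>
/// <param name="include">Navigation properties to include</param>
/// <param name="pageIndex">Current page. Defaults to the first page (1)</param>
/// <param name="pageSize">Page size. Defaults to 20 items</param>
/// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
/// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
/// <returns>The requested page of projected results</returns>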
IPagedList<TResult> GetPagedList<TResult>(
Expression<Func<TEntity, TResult>> selector,
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
int pageIndex = 1,
int pageSize = 20,
bool disableTracking = true,
bool ignoreQueryFilters = false
) where TResult : class;
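/// <summary>
/// Gets a page of projected data (asynchronously).
/// Change tracking is disabled by default, so changes to the returned entities are not saved automatically.
/// Global query filters are applied by default.
/// </summary>
/// <typeparam name="TResult">Type of the output data</typeparam>
/// <param name="selector">Projection selector</param>
/// <param name="predicate">Filter expression</param>
/// <param name="orderBy">Ordering function</param>
/// <param name="include">Navigation properties to include</param>
/// <param name="pageIndex">Current page. Defaults to the first page (1)</param>
/// <param name="pageSize">Page size. Defaults to 20 items</param>
/// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
/// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
/// <param name="cancellationToken">Token to observe for cancellation</param>
/// <returns>The requested page of projected results</returns>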
Task<IPagedList<TResult>> GetPagedListAsync<TResult>(
Expression<Func<TEntity, TResult>> selector,
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
int pageIndex = 1,
int pageSize = 20,
bool disableTracking = true,
bool ignoreQueryFilters = false,
CancellationToken cancellationToken = default) where TResult : class;
#endregion
#region GetFirstOrDefault
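/// <summary>
/// Gets the first element of the sequence that satisfies the condition,
/// or the default value if no element does.
/// Change tracking is disabled by default, so changes to the returned entity are not saved automatically.
/// Global query filters are applied by default.
/// </summary>
/// <param name="predicate">Filter expression</param>
/// <param name="orderBy">Ordering function</param>
/// <param name="include">Navigation properties to include</param>
/// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
/// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
/// <returns>The first matching entity, or null</returns>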
TEntity GetFirstOrDefault(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false);
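/// <summary>
/// Gets the first element of the sequence that satisfies the condition (asynchronously),
/// or the default value if no element does.
/// Change tracking is disabled by default, so changes to the returned entity are not saved automatically.
/// Global query filters are applied by default.
/// </summary>
/// <param name="predicate">Filter expression</param>
/// <param name="orderBy">Ordering function</param>
/// <param name="include">Navigation properties to include</param>
/// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
/// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
/// <param name="cancellationToken">Token to observe for cancellation</param>
/// <returns>The first matching entity, or null</returns>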
Task<TEntity> GetFirstOrDefaultAsync(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false,
CancellationToken cancellationToken = default);
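/// <summary>
/// Gets the first element of the sequence that satisfies the condition, projected with the selector,
/// or the default value if no element does.
/// Change tracking is disabled by default, so changes to the returned data are not saved automatically.
/// Global query filters are applied by default.
/// </summary>
/// <typeparam name="TResult">Type of the output data</typeparam>
/// <param name="selector">Projection selector</param>
/// <param name="predicate">Filter expression</param>
/// <param name="orderBy">Ordering function</param>
/// <param name="include">Navigation properties to include</param>
/// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
/// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
/// <returns>The first projected result, or the default value of TResult</returns>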
TResult GetFirstOrDefault<TResult>(
Expression<Func<TEntity, TResult>> selector,
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false);
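/// <summary>
/// Gets the first element of the sequence that satisfies the condition, projected with the selector (asynchronously),
/// or the default value if no element does.
/// Change tracking is disabled by default, so changes to the returned data are not saved automatically.
/// Global query filters are applied by default.
/// </summary>
/// <typeparam name="TResult">Type of the output data</typeparam>
/// <param name="selector">Projection selector</param>
/// <param name="predicate">Filter expression</param>
/// <param name="orderBy">Ordering function</param>
/// <param name="include">Navigation properties to include</param>
/// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
/// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
/// <param name="cancellationToken">Token to observe for cancellation</param>
/// <returns>The first projected result, or the default value of TResult</returns>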
Task<TResult> GetFirstOrDefaultAsync<TResult>(
Expression<Func<TEntity, TResult>> selector,
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false,
CancellationToken cancellationToken = default);
#endregion
#region Find
/// <summary>
/// Finds an entity with the given primary key values. An entity already tracked by the context is returned without querying the database; otherwise the database is queried and the entity, if found, is attached to the context and returned. If no entity is found, null is returned.
/// </summary>
/// <param name="keyValues">The values of the primary key for the entity to be found.</param>
/// <returns>The found entity or null.</returns>
TEntity Find(params object[] keyValues);
/// <summary>
/// Finds an entity with the given primary key values. An entity already tracked by the context is returned without querying the database; otherwise the database is queried and the entity, if found, is attached to the context and returned. If no entity is found, null is returned.
/// </summary>
/// <param name="keyValues">The values of the primary key for the entity to be found.</param>
/// <returns>A <see cref="ValueTask{TEntity}"/> that represents the asynchronous find operation. The task result contains the found entity or null.</returns>
ValueTask<TEntity> FindAsync(params object[] keyValues);
/// <summary>
/// Finds an entity with the given primary key values. An entity already tracked by the context is returned without querying the database; otherwise the database is queried and the entity, if found, is attached to the context and returned. If no entity is found, null is returned.
/// </summary>
/// <param name="keyValues">The values of the primary key for the entity to be found.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete.</param>
/// <returns>A <see cref="ValueTask{TEntity}"/> that represents the asynchronous find operation. The task result contains the found entity or null.</returns>
ValueTask<TEntity> FindAsync(object[] keyValues, CancellationToken cancellationToken);
#endregion
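#region sql, count, exist
/// <summary>
/// Queries entities with a raw SQL statement
/// </summary>
/// <param name="sql">The raw SQL query</param>
/// <param name="parameters">Values for the parameters used in the SQL</param>
/// <returns>A query over the entities returned by the SQL</returns>
IQueryable<TEntity> FromSql(string sql, params object[] parameters);
/// <summary>
/// Counts the entities that match the condition
/// </summary>
/// <param name="predicate">Filter expression (optional)</param>
/// <returns>The number of matching entities</returns>
int Count(Expression<Func<TEntity, bool>> predicate = null);
/// <summary>
/// Counts the entities that match the condition (asynchronously)
/// </summary>
/// <param name="predicate">Filter expression (optional)</param>
/// <returns>The number of matching entities</returns>
Task<int> CountAsync(Expression<Func<TEntity, bool>> predicate = null);
/// <summary>
/// Determines whether any entity matches the condition
/// </summary>
/// <param name="predicate">Filter expression (optional)</param>
/// <returns>True if at least one matching entity exists</returns>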
bool Exists(Expression<Func<TEntity, bool>> predicate = null);
#endregion
#region Insert
/// <summary>
/// Inserts a new entity synchronously.
/// </summary>
/// <param name="entity">The entity to insert.</param>
/// <returns>The entity that was added.</returns>
TEntity Insert(TEntity entity);
/// <summary>
/// Inserts a range of entities synchronously.
/// </summary>
/// <param name="entities">The entities to insert.</param>
void Insert(params TEntity[] entities);
/// <summary>
/// Inserts a range of entities synchronously.
/// </summary>
/// <param name="entities">The entities to insert.</param>
void Insert(IEnumerable<TEntity> entities);
/// <summary>
/// Inserts a new entity asynchronously.
/// </summary>
/// <param name="entity">The entity to insert.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete.</param>
/// <returns>A <see cref="ValueTask{TResult}"/> that represents the asynchronous insert operation. Its result is the <see cref="EntityEntry{TEntity}"/> of the added entity.</returns>
ValueTask<EntityEntry<TEntity>> InsertAsync(TEntity entity, CancellationToken cancellationToken = default);
/// <summary>
/// Inserts a range of entities asynchronously.
/// </summary>
/// <param name="entities">The entities to insert.</param>
/// <returns>A <see cref="Task"/> that represents the asynchronous insert operation.</returns>
Task InsertAsync(params TEntity[] entities);
/// <summary>
/// Inserts a range of entities asynchronously.
/// </summary>
/// <param name="entities">The entities to insert.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete.</param>
/// <returns>A <see cref="Task"/> that represents the asynchronous insert operation.</returns>
Task InsertAsync(IEnumerable<TEntity> entities, CancellationToken cancellationToken = default);
#endregion
#region Update
/// <summary>
/// Updates the specified entity.
/// </summary>
/// <param name="entity">The entity.</param>
void Update(TEntity entity);
/// <summary>
/// Updates the specified entities.
/// </summary>
/// <param name="entities">The entities.</param>
void Update(params TEntity[] entities);
/// <summary>
/// Updates the specified entities.
/// </summary>
/// <param name="entities">The entities.</param>
void Update(IEnumerable<TEntity> entities);
#endregion
#region Delete
/// <summary>
/// Deletes the entity by the specified primary key.
/// </summary>
/// <param name="id">The primary key value.</param>
void Delete(object id);
/// <summary>
/// Deletes the specified entity.
/// </summary>
/// <param name="entity">The entity to delete.</param>
void Delete(TEntity entity);
/// <summary>
/// Deletes the specified entities.
/// </summary>
/// <param name="entities">The entities.</param>
void Delete(params TEntity[] entities);
/// <summary>
/// Deletes the specified entities.
/// </summary>
/// <param name="entities">The entities.</param>
void Delete(IEnumerable<TEntity> entities);
#endregion
}
}
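Most `IRepository<TEntity>` methods take the same optional `predicate`, `orderBy`, and `include` delegates. The sketch below is a hypothetical in-memory stand-in (`QueryCompositionSketch.GetAll` and the entity `User` are illustration-only names) that shows how a caller supplies them and in which relative order they are applied. It runs over `AsQueryable()` on a list, so the EF Core-only steps `AsNoTracking`, `Include`, and `IgnoreQueryFilters` are left out.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical entity used only for this sketch.
public class User
{
    public int Id { get; set; }
    public string Name { get; set; }
    public bool IsActive { get; set; }
}

// Hypothetical stand-in for Repository<TEntity>.GetAll<TResult>: filter, then the
// caller-supplied ordering, then the projection (EF Core-only steps omitted).
public static class QueryCompositionSketch
{
    public static IQueryable<TResult> GetAll<TEntity, TResult>(
        IQueryable<TEntity> source,
        Expression<Func<TEntity, TResult>> selector,
        Expression<Func<TEntity, bool>> predicate = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null)
    {
        IQueryable<TEntity> query = source;
        if (predicate != null)
        {
            query = query.Where(predicate);   // filter first
        }
        if (orderBy != null)
        {
            query = orderBy(query);           // the caller decides the ordering
        }
        return query.Select(selector);        // project last
    }
}
```

Against the repository in this chapter, a comparable call would be `repo.GetAll(u => u.Name, predicate: u => u.IsActive, orderBy: q => q.OrderByDescending(u => u.Id))`, adding for example `include: q => q.Include(u => u.Role)` when a navigation property (here a hypothetical `Role`) must be loaded.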
Repository.cs
using MS.UnitOfWork.Collections;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.EntityFrameworkCore.Query;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
namespace MS.UnitOfWork
{
/// <summary>
/// Default implementation of the generic repository
/// </summary>
/// <typeparam name="TEntity">Entity type</typeparam>
public class Repository<TEntity> : IRepository<TEntity> where TEntity : class
{
protected readonly DbContext _dbContext;
protected readonly DbSet<TEntity> _dbSet;
public Repository(DbContext dbContext)
{
_dbContext = dbContext ?? throw new ArgumentNullException(nameof(dbContext));
_dbSet = _dbContext.Set<TEntity>();
}
#region GetAll
public IQueryable<TEntity> GetAll() => _dbSet;
public IQueryable<TEntity> GetAll(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false)
{
IQueryable<TEntity> query = _dbSet;
if (disableTracking)
{
query = query.AsNoTracking();
}
if (include != null)
{
query = include(query);
}
if (predicate != null)
{
query = query.Where(predicate);
}
if (ignoreQueryFilters)
{
query = query.IgnoreQueryFilters();
}
if (orderBy != null)
{
return orderBy(query);
}
else
{
return query;
}
}
public IQueryable<TResult> GetAll<TResult>(
Expression<Func<TEntity, TResult>> selector,
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false) where TResult : class
{
IQueryable<TEntity> query = _dbSet;
if (disableTracking)
{
query = query.AsNoTracking();
}
if (include != null)
{
query = include(query);
}
if (predicate != null)
{
query = query.Where(predicate);
}
if (ignoreQueryFilters)
{
query = query.IgnoreQueryFilters();
}
if (orderBy != null)
{
return orderBy(query).Select(selector);
}
else
{
return query.Select(selector);
}
}
public async Task<IList<TEntity>> GetAllAsync(Expression<Func<TEntity, bool>> predicate = null, Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null, Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null, bool disableTracking = true, bool ignoreQueryFilters = false)
{
IQueryable<TEntity> query = _dbSet;
if (disableTracking)
{
query = query.AsNoTracking();
}
if (include != null)
{
query = include(query);
}
if (predicate != null)
{
query = query.Where(predicate);
}
if (ignoreQueryFilters)
{
query = query.IgnoreQueryFilters();
}
if (orderBy != null)
{
return await orderBy(query).ToListAsync();
}
else
{
return await query.ToListAsync();
}
}
#endregion
#region GetPagedList
public virtual IPagedList<TEntity> GetPagedList(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
int pageIndex = 1,
int pageSize = 20,
bool disableTracking = true,
bool ignoreQueryFilters = false)
{
IQueryable<TEntity> query = _dbSet;
if (disableTracking)
{
query = query.AsNoTracking();
}
if (include != null)
{
query = include(query);
}
if (predicate != null)
{
query = query.Where(predicate);
}
if (ignoreQueryFilters)
{
query = query.IgnoreQueryFilters();
}
if (orderBy != null)
{
return orderBy(query).ToPagedList(pageIndex, pageSize);
}
else
{
return query.ToPagedList(pageIndex, pageSize);
}
}
public virtual async Task<IPagedList<TEntity>> GetPagedListAsync(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
int pageIndex = 1,
int pageSize = 20,
bool disableTracking = true,
bool ignoreQueryFilters = false,
CancellationToken cancellationToken = default)
{
IQueryable<TEntity> query = _dbSet;
if (disableTracking)
{
query = query.AsNoTracking();
}
if (include != null)
{
query = include(query);
}
if (predicate != null)
{
query = query.Where(predicate);
}
if (ignoreQueryFilters)
{
query = query.IgnoreQueryFilters();
}
if (orderBy != null)
{
return await orderBy(query).ToPagedListAsync(pageIndex, pageSize, 1, cancellationToken);
}
else
{
return await query.ToPagedListAsync(pageIndex, pageSize, 1, cancellationToken);
}
}
public virtual IPagedList<TResult> GetPagedList<TResult>(
Expression<Func<TEntity, TResult>> selector,
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
int pageIndex = 1,
int pageSize = 20,
bool disableTracking = true,
bool ignoreQueryFilters = false)
where TResult : class
{
IQueryable<TEntity> query = _dbSet;
if (disableTracking)
{
query = query.AsNoTracking();
}
if (include != null)
{
query = include(query);
}
if (predicate != null)
{
query = query.Where(predicate);
}
if (ignoreQueryFilters)
{
query = query.IgnoreQueryFilters();
}
if (orderBy != null)
{
return orderBy(query).Select(selector).ToPagedList(pageIndex, pageSize);
}
else
{
return query.Select(selector).ToPagedList(pageIndex, pageSize);
}
}
public virtual async Task<IPagedList<TResult>> GetPagedListAsync<TResult>(
Expression<Func<TEntity, TResult>> selector,
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
int pageIndex = 1,
int pageSize = 20,
bool disableTracking = true,
bool ignoreQueryFilters = false,
CancellationToken cancellationToken = default)
where TResult : class
{
IQueryable<TEntity> query = _dbSet;
if (disableTracking)
{
query = query.AsNoTracking();
}
if (include != null)
{
query = include(query);
}
if (predicate != null)
{
query = query.Where(predicate);
}
if (ignoreQueryFilters)
{
query = query.IgnoreQueryFilters();
}
if (orderBy != null)
{
return await orderBy(query).Select(selector).ToPagedListAsync(pageIndex, pageSize, 1, cancellationToken);
}
else
{
return await query.Select(selector).ToPagedListAsync(pageIndex, pageSize, 1, cancellationToken);
}
}
#endregion
#region GetFirstOrDefault
public virtual TEntity GetFirstOrDefault(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false)
{
IQueryable<TEntity> query = _dbSet;
if (disableTracking)
{
query = query.AsNoTracking();
}
if (include != null)
{
query = include(query);
}
if (predicate != null)
{
query = query.Where(predicate);
}
if (ignoreQueryFilters)
{
query = query.IgnoreQueryFilters();
}
if (orderBy != null)
{
return orderBy(query).FirstOrDefault();
}
else
{
return query.FirstOrDefault();
}
}
public virtual async Task<TEntity> GetFirstOrDefaultAsync(
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false,
CancellationToken cancellationToken = default)
{
IQueryable<TEntity> query = _dbSet;
if (disableTracking)
{
query = query.AsNoTracking();
}
if (include != null)
{
query = include(query);
}
if (predicate != null)
{
query = query.Where(predicate);
}
if (ignoreQueryFilters)
{
query = query.IgnoreQueryFilters();
}
if (orderBy != null)
{
return await orderBy(query).FirstOrDefaultAsync(cancellationToken);
}
else
{
return await query.FirstOrDefaultAsync(cancellationToken);
}
}
public virtual TResult GetFirstOrDefault<TResult>(
Expression<Func<TEntity, TResult>> selector,
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false)
{
IQueryable<TEntity> query = _dbSet;
if (disableTracking)
{
query = query.AsNoTracking();
}
if (include != null)
{
query = include(query);
}
if (predicate != null)
{
query = query.Where(predicate);
}
if (ignoreQueryFilters)
{
query = query.IgnoreQueryFilters();
}
if (orderBy != null)
{
return orderBy(query).Select(selector).FirstOrDefault();
}
else
{
return query.Select(selector).FirstOrDefault();
}
}
public virtual async Task<TResult> GetFirstOrDefaultAsync<TResult>(
Expression<Func<TEntity, TResult>> selector,
Expression<Func<TEntity, bool>> predicate = null,
Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
bool disableTracking = true,
bool ignoreQueryFilters = false,
CancellationToken cancellationToken = default)
{
IQueryable<TEntity> query = _dbSet;
if (disableTracking)
{
query = query.AsNoTracking();
}
if (include != null)
{
query = include(query);
}
if (predicate != null)
{
query = query.Where(predicate);
}
if (ignoreQueryFilters)
{
query = query.IgnoreQueryFilters();
}
if (orderBy != null)
{
return await orderBy(query).Select(selector).FirstOrDefaultAsync(cancellationToken);
}
else
{
return await query.Select(selector).FirstOrDefaultAsync(cancellationToken);
}
}
#endregion
#region Find
public virtual TEntity Find(params object[] keyValues) => _dbSet.Find(keyValues);
public virtual ValueTask<TEntity> FindAsync(params object[] keyValues) => _dbSet.FindAsync(keyValues);
public virtual ValueTask<TEntity> FindAsync(object[] keyValues, CancellationToken cancellationToken) => _dbSet.FindAsync(keyValues, cancellationToken);
#endregion
#region sql、count、exist
public virtual IQueryable<TEntity> FromSql(string sql, params object[] parameters) => _dbSet.FromSqlRaw(sql, parameters);
public virtual int Count(Expression<Func<TEntity, bool>> predicate = null)
{
if (predicate == null)
{
return _dbSet.Count();
}
else
{
return _dbSet.Count(predicate);
}
}
public virtual async Task<int> CountAsync(Expression<Func<TEntity, bool>> predicate = null)
{
if (predicate == null)
{
return await _dbSet.CountAsync();
}
else
{
return await _dbSet.CountAsync(predicate);
}
}
public virtual bool Exists(Expression<Func<TEntity, bool>> predicate = null)
{
if (predicate == null)
{
return _dbSet.Any();
}
else
{
return _dbSet.Any(predicate);
}
}
#endregion
#region Insert
public virtual TEntity Insert(TEntity entity)
{
return _dbSet.Add(entity).Entity;
}
public virtual void Insert(params TEntity[] entities) => _dbSet.AddRange(entities);
public virtual void Insert(IEnumerable<TEntity> entities) => _dbSet.AddRange(entities);
public virtual ValueTask<EntityEntry<TEntity>> InsertAsync(TEntity entity, CancellationToken cancellationToken = default(CancellationToken))
{
return _dbSet.AddAsync(entity, cancellationToken);
// Shadow properties?
//var property = _dbContext.Entry(entity).Property("Created");
//if (property != null) {
//property.CurrentValue = DateTime.Now;
//}
}
public virtual Task InsertAsync(params TEntity[] entities) => _dbSet.AddRangeAsync(entities);
public virtual Task InsertAsync(IEnumerable<TEntity> entities, CancellationToken cancellationToken = default(CancellationToken)) => _dbSet.AddRangeAsync(entities, cancellationToken);
#endregion
#region Update
public virtual void Update(TEntity entity)
{
_dbSet.Update(entity);
}
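// Note: despite its name this method is synchronous; DbSet.Update only sets the entity state in the change tracker, so there is no I/O to await.
public virtual void UpdateAsync(TEntity entity)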
{
_dbSet.Update(entity);
}
public virtual void Update(params TEntity[] entities) => _dbSet.UpdateRange(entities);
public virtual void Update(IEnumerable<TEntity> entities) => _dbSet.UpdateRange(entities);
#endregion
#region Delete
public virtual void Delete(TEntity entity) => _dbSet.Remove(entity);
public virtual void Delete(object id)
{
var entity = _dbSet.Find(id);
if (entity != null)
{
Delete(entity);
}
}
public virtual void Delete(params TEntity[] entities) => _dbSet.RemoveRange(entities);
public virtual void Delete(IEnumerable<TEntity> entities) => _dbSet.RemoveRange(entities);
#endregion
}
}
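Notes
- The repository encapsulates the common create, read, update and delete operations
- Methods whose names end in Async are asynchronous
- The method documentation comments are on the interface
- Queries:
  - GetAll returns every entity that satisfies the condition (watch out for performance)
  - GetPagedList performs a paged query
  - GetFirstOrDefault returns the first element that satisfies the condition
  - Find looks up an element by primary key, for example by an Id value
  - FromSql runs a raw SQL query
  - Count returns the number of matching records
  - Exists checks whether any matching record exists
- The query methods accept several options:
  - Paged queries start at page 1 (pageIndex = 1) and return 20 records per page by default
  - Tracking is disabled by default (disableTracking = true)
  - Global query filters are applied by default (ignoreQueryFilters = false)
  - The selector parameter projects the query results into another type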
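The paging behaviour described above (filter, optional ordering, projection through selector, then Skip/Take driven by pageIndex, pageSize and an indexFrom of 1) can be tried out without a database. The sketch below is a simplified model, not the repository itself: it runs LINQ against an in-memory IQueryable created with AsQueryable, leaves out Include, AsNoTracking and IgnoreQueryFilters (which only matter for EF Core), and the RoleRow type and PagingSketch.GetPage method are hypothetical names used only here. The paging math assumes the Arch/UnitOfWork PagedList convention of skipping (pageIndex - indexFrom) * pageSize rows.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical sample entity, used only for this sketch.
public class RoleRow
{
    public long Id { get; set; }
    public string Name { get; set; }
}

public static class PagingSketch
{
    // Composes the query in the same order as GetPagedList:
    // filter -> optional ordering -> projection -> paging.
    // indexFrom defaults to 1, matching the value passed in GetPagedListAsync.
    public static (List<TResult> Items, int TotalCount, int TotalPages) GetPage<TEntity, TResult>(
        IQueryable<TEntity> source,
        Expression<Func<TEntity, TResult>> selector,
        Expression<Func<TEntity, bool>> predicate = null,
        Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
        int pageIndex = 1,
        int pageSize = 20,
        int indexFrom = 1)
    {
        if (indexFrom > pageIndex)
            throw new ArgumentException("indexFrom must not be greater than pageIndex");

        IQueryable<TEntity> query = source;
        if (predicate != null)
            query = query.Where(predicate);
        if (orderBy != null)
            query = orderBy(query);

        IQueryable<TResult> projected = query.Select(selector);
        int totalCount = projected.Count();
        int totalPages = (int)Math.Ceiling(totalCount / (double)pageSize);
        List<TResult> items = projected
            .Skip((pageIndex - indexFrom) * pageSize) // page 1 skips nothing
            .Take(pageSize)
            .ToList();
        return (items, totalCount, totalPages);
    }

    public static void Main()
    {
        IQueryable<RoleRow> roles = Enumerable.Range(1, 45)
            .Select(i => new RoleRow { Id = i, Name = "role" + i })
            .AsQueryable();

        var page = GetPage(roles,
            selector: r => r.Name,
            predicate: r => r.Id % 2 == 1,                 // 23 odd ids
            orderBy: q => q.OrderByDescending(r => r.Id),
            pageIndex: 2,
            pageSize: 10);

        // Prints: 23 rows, 3 pages, first on page 2: role25
        Console.WriteLine($"{page.TotalCount} rows, {page.TotalPages} pages, first on page 2: {page.Items[0]}");
    }
}
```

Without an orderBy, a relational database does not guarantee row order, so consecutive pages can overlap or skip rows; passing an orderBy to paged queries avoids that.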
Unit of work
In the MS.UnitOfWork project, add a UnitOfWork folder, then add the IUnitOfWork.cs and UnitOfWork.cs classes to it.
IUnitOfWork.cs
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Storage;
using System;
using System.Linq;
using System.Threading.Tasks;
namespace MS.UnitOfWork
{
/// <summary>
/// Defines the unit-of-work interface
/// </summary>
public interface IUnitOfWork<TContext> : IDisposable where TContext : DbContext
{
/// <summary>
/// Gets the DbContext
/// </summary>
/// <returns></returns>
TContext DbContext { get; }
/// <summary>
/// Begins a transaction
/// </summary>
/// <returns></returns>
IDbContextTransaction BeginTransaction();
/// <summary>
/// Gets the repository for the specified entity type
/// </summary>
/// <typeparam name="TEntity"></typeparam>
/// <param name="hasCustomRepository">Set to true if a custom repository exists</param>
/// <returns></returns>
IRepository<TEntity> GetRepository<TEntity>(bool hasCustomRepository = false) where TEntity : class;
/// <summary>
/// Saves the changes made in the DbContext
/// </summary>
/// <returns>The number of state entries written to the database</returns>
int SaveChanges();
/// <summary>
/// Saves the changes made in the DbContext (asynchronous)
/// </summary>
/// <returns>The number of state entries written to the database</returns>
Task<int> SaveChangesAsync();
/// <summary>
/// Executes a raw SQL command
/// </summary>
/// <param name="sql">SQL statement</param>
/// <param name="parameters">Parameters</param>
/// <returns>The number of rows affected</returns>
int ExecuteSqlCommand(string sql, params object[] parameters);
/// <summary>
/// Queries the specified data with a raw SQL query
/// </summary>
/// <typeparam name="TEntity"></typeparam>
/// <param name="sql">SQL query</param>
/// <param name="parameters">Parameters</param>
/// <returns></returns>
IQueryable<TEntity> FromSql<TEntity>(string sql, params object[] parameters) where TEntity : class;
}
}
UnitOfWork.cs
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Storage;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
namespace MS.UnitOfWork
{
/// <summary>
/// Default implementation of the unit of work.
/// </summary>
/// <typeparam name="TContext"></typeparam>
public class UnitOfWork<TContext> : IUnitOfWork<TContext> where TContext : DbContext
{
protected readonly TContext _context;
protected bool _disposed = false;
protected Dictionary<Type, object> _repositories;
public UnitOfWork(TContext context)
{
_context = context ?? throw new ArgumentNullException(nameof(context));
}
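/// <summary>
/// Gets the DbContext
/// </summary>
public TContext DbContext => _context;
/// <summary>
/// Begins a transaction
/// </summary>
/// <returns></returns>
public IDbContextTransaction BeginTransaction()
{
return _context.Database.BeginTransaction();
}
/// <summary>
/// Gets the repository for the specified entity type; repositories are cached per entity type and all share this unit of work's DbContext
/// </summary>
/// <typeparam name="TEntity"></typeparam>
/// <param name="hasCustomRepository">Not used by this implementation; a generic <see cref="Repository{TEntity}"/> is always returned</param>
/// <returns></returns>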
public IRepository<TEntity> GetRepository<TEntity>(bool hasCustomRepository = false) where TEntity : class
{
if (_repositories == null)
{
_repositories = new Dictionary<Type, object>();
}
Type type = typeof(IRepository<TEntity>);
if (!_repositories.TryGetValue(type, out object repo))
{
IRepository<TEntity> newRepo = new Repository<TEntity>(_context);
_repositories.Add(type, newRepo);
return newRepo;
}
return (IRepository<TEntity>)repo;
}
/// <summary>
/// Executes a raw SQL command
/// </summary>
/// <param name="sql">SQL statement</param>
/// <param name="parameters">Parameters</param>
/// <returns>The number of rows affected</returns>
public int ExecuteSqlCommand(string sql, params object[] parameters) => _context.Database.ExecuteSqlRaw(sql, parameters);
/// <summary>
/// Queries the specified data with a raw SQL query
/// </summary>
/// <typeparam name="TEntity"></typeparam>
/// <param name="sql">SQL query</param>
/// <param name="parameters">Parameters</param>
/// <returns></returns>
public IQueryable<TEntity> FromSql<TEntity>(string sql, params object[] parameters) where TEntity : class => _context.Set<TEntity>().FromSqlRaw(sql, parameters);
/// <summary>
/// Saves the changes made in the DbContext
/// </summary>
/// <returns>The number of state entries written to the database</returns>
public int SaveChanges()
{
return _context.SaveChanges();
}
/// <summary>
/// Saves the changes made in the DbContext (asynchronous)
/// </summary>
/// <returns>The number of state entries written to the database</returns>
public async Task<int> SaveChangesAsync()
{
return await _context.SaveChangesAsync();
}
public void Dispose()
{
Dispose(true);
GC.SuppressFinalize(this);
}
protected virtual void Dispose(bool disposing)
{
if (!_disposed)
{
if (disposing)
{
// clear repositories
if (_repositories != null)
{
_repositories.Clear();
}
// dispose the db context.
_context.Dispose();
}
}
_disposed = true;
}
}
}
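Notes
- Repositories and the DbContext are both obtained from the unit of work
- Transactions are also started through the unit of work
- After modifying data through a repository, call the unit of work's SaveChanges (or SaveChangesAsync) to commit the changes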
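These notes rely on GetRepository caching one repository per entity type, with every repository sharing the unit of work's single DbContext; that shared context is what lets one SaveChanges call commit changes made through several repositories. The minimal sketch below reproduces only that caching logic, using hypothetical stand-in types (FakeContext, FakeRepository, FakeUnitOfWork, Role, User) so that it runs without EF Core.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-ins for DbContext and Repository<TEntity>, used only to show
// how GetRepository caches one repository per entity type inside a unit of work.
public class FakeContext { }

public class FakeRepository<TEntity> where TEntity : class
{
    public FakeRepository(FakeContext context) => Context = context;
    public FakeContext Context { get; }
}

public class FakeUnitOfWork
{
    private readonly FakeContext _context = new FakeContext();
    private readonly Dictionary<Type, object> _repositories = new Dictionary<Type, object>();

    public FakeRepository<TEntity> GetRepository<TEntity>() where TEntity : class
    {
        Type type = typeof(FakeRepository<TEntity>);
        if (!_repositories.TryGetValue(type, out object repo))
        {
            // First request for this entity type: create and cache the repository.
            repo = new FakeRepository<TEntity>(_context);
            _repositories.Add(type, repo);
        }
        return (FakeRepository<TEntity>)repo;
    }
}

public class Role { }
public class User { }

public static class RepositoryCacheDemo
{
    public static void Main()
    {
        var uow = new FakeUnitOfWork();
        var roles1 = uow.GetRepository<Role>();
        var roles2 = uow.GetRepository<Role>();
        var users = uow.GetRepository<User>();

        Console.WriteLine(object.ReferenceEquals(roles1, roles2));              // True: cached per type
        Console.WriteLine(object.ReferenceEquals(roles1.Context, users.Context)); // True: one shared context
    }
}
```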
Wrapping the IoC registration
Add a UnitOfWorkServiceExtensions.cs class to the MS.UnitOfWork project:
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
namespace MS.UnitOfWork
{
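/// <summary>
/// Extension methods that register the unit of work for dependency injection in an <see cref="IServiceCollection"/>
/// </summary>
public static class UnitOfWorkServiceExtensions
{
/// <summary>
/// Registers the unit of work for the given context as a service in the <see cref="IServiceCollection"/>.
/// The DbContext itself is registered as well.
/// </summary>
/// <typeparam name="TContext">The DbContext type</typeparam>
/// <param name="services">The service collection</param>
/// <param name="action">Configures the DbContext options, such as the database provider and connection string</param>
/// <remarks>Both the DbContext and <see cref="IUnitOfWork{TContext}"/> are registered with a scoped lifetime, i.e. one instance per request.</remarks>
/// <returns>The same service collection, so calls can be chained</returns>
public static IServiceCollection AddUnitOfWorkService<TContext>(this IServiceCollection services, System.Action<DbContextOptionsBuilder> action) where TContext : DbContext
{
//Register the DbContext
services.AddDbContext<TContext>(action);
//Register the unit of work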
services.AddScoped<IUnitOfWork<TContext>, UnitOfWork<TContext>>();
return services;
}
}
}
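With this in place, a project that wants to use the unit of work only needs to call AddUnitOfWorkService in Startup to register it.
The finished project looks like the figure below:
Usage examples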
using (var tran = _unitOfWork.BeginTransaction())//Begin a transaction
{
Role newRow = _mapper.Map<Role>(viewModel);
newRow.Id = _idWorker.NextId();//Get a snowflake Id
newRow.Creator = 1219490056771866624;//Login is not implemented yet, so the current user is unknown; use a placeholder value for now and fix it later
newRow.CreateTime = DateTime.Now;
_unitOfWork.GetRepository<Role>().Insert(newRow);
await _unitOfWork.SaveChangesAsync();
await tran.CommitAsync();//Commit the transaction
}
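The code above starts a transaction through the unit of work and wraps it in a using block. The transaction is committed only when tran.CommitAsync() is called; if an error occurs before that, the transaction is disposed at the end of the using block without having been committed, and it is rolled back automatically.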
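The automatic rollback can be modelled with a hypothetical in-memory transaction. The sketch below is not EF Core code: SketchTransaction and TransactionDemo are illustrative names, and the class only imitates the relevant behaviour of IDbContextTransaction, namely that disposing a transaction that was never committed rolls it back. An exception thrown before Commit leaves the using block early, Dispose still runs, and the rows written inside the transaction disappear.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical in-memory transaction: disposing it without a prior Commit
// rolls back everything written to the store after it was started.
public class SketchTransaction : IDisposable
{
    private readonly List<string> _store;
    private readonly int _countAtStart;
    private bool _committed;

    public SketchTransaction(List<string> store)
    {
        _store = store;
        _countAtStart = store.Count;
    }

    public void Commit() => _committed = true;

    public void Dispose()
    {
        if (!_committed)
        {
            // Roll back: drop every row added after the transaction started.
            _store.RemoveRange(_countAtStart, _store.Count - _countAtStart);
        }
    }
}

public static class TransactionDemo
{
    public static List<string> Run(bool failBeforeCommit)
    {
        var table = new List<string>();
        try
        {
            using (var tran = new SketchTransaction(table))
            {
                table.Add("new role");          // stands in for Insert + SaveChangesAsync
                if (failBeforeCommit)
                    throw new InvalidOperationException("simulated failure");
                tran.Commit();                  // stands in for tran.CommitAsync()
            }
        }
        catch (InvalidOperationException)
        {
            // The exception left the using block; Dispose has already rolled back.
        }
        return table;
    }

    public static void Main()
    {
        Console.WriteLine(Run(failBeforeCommit: false).Count); // 1: committed
        Console.WriteLine(Run(failBeforeCommit: true).Count);  // 0: rolled back
    }
}
```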
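//Load the record from the database
var row = await _unitOfWork.GetRepository<Role>().FindAsync(viewModel.Id);//Find first checks the entities already tracked by the DbContext; viewModel.CheckField already loaded this record for validation, so as long as that entity is still tracked, no second database query is made here
//Update the corresponding values
row.Name = viewModel.Name;
row.DisplayName = viewModel.DisplayName;
row.Remark = viewModel.Remark;
row.Modifier = 1219490056771866624;//Login is not implemented yet, so the current user is unknown; use a placeholder value for now and fix it later
row.ModifyTime = DateTime.Now;
_unitOfWork.GetRepository<Role>().Update(row);
await _unitOfWork.SaveChangesAsync();//Commit the changes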
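- The code above loads a record by its primary key Id and then updates it.
- Alternatively, load the record with GetFirstOrDefault and set disableTracking to false to enable tracking. Because the returned entity is tracked by the DbContext, you can modify it and call SaveChangesAsync directly without calling Update; tracking is what makes the Update call unnecessary.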
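The difference between tracked and untracked entities can be shown with a toy snapshot tracker. This is a hypothetical, heavily simplified model in the spirit of EF Core's default snapshot change tracking: the tracker records an entity's original values when it is attached, and SaveChanges compares the current values against that snapshot. RoleEntity and SketchTracker are illustrative names; the real change tracker covers every property, entity states and relationships.

```csharp
using System;
using System.Collections.Generic;

public class RoleEntity
{
    public long Id { get; set; }
    public string Name { get; set; }
}

// Toy snapshot tracker: only the Name property is snapshotted in this model.
public class SketchTracker
{
    // Keys use reference equality, like the identity of a tracked entity instance.
    private readonly Dictionary<RoleEntity, string> _snapshots = new Dictionary<RoleEntity, string>();

    // Comparable to a query with disableTracking: false; the returned entity is tracked.
    public RoleEntity Attach(RoleEntity entity)
    {
        _snapshots[entity] = entity.Name;
        return entity;
    }

    // Comparable to SaveChanges: find modified entities by comparing with the snapshot.
    public int SaveChanges()
    {
        int changed = 0;
        foreach (var entity in new List<RoleEntity>(_snapshots.Keys))
        {
            if (_snapshots[entity] != entity.Name)
            {
                changed++;                        // a real context would issue an UPDATE here
                _snapshots[entity] = entity.Name; // the saved value becomes the new original
            }
        }
        return changed;
    }
}

public static class TrackingDemo
{
    public static void Main()
    {
        var tracker = new SketchTracker();
        RoleEntity tracked = tracker.Attach(new RoleEntity { Id = 1, Name = "admin" });
        var untracked = new RoleEntity { Id = 2, Name = "guest" }; // like a no-tracking query result

        tracked.Name = "administrator";
        untracked.Name = "visitor";

        Console.WriteLine(tracker.SaveChanges()); // 1: only the tracked entity is detected
        Console.WriteLine(tracker.SaveChanges()); // 0: nothing changed since the last save
    }
}
```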