Repository.cs
4.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore;
using Rcs.Domain.Repositories;
using Rcs.Infrastructure.DB.MsSql;
namespace Rcs.Infrastructure.DB.Repositories
{
    /// <summary>
    /// Generic repository implementation backed by the EF Core <see cref="DbSet{TEntity}"/>
    /// that <see cref="AppDbContext"/> exposes for <typeparamref name="TEntity"/>.
    /// </summary>
    /// <remarks>
    /// The add/update/delete methods only stage changes in the context's change tracker.
    /// They are persisted when <see cref="SaveChangesAsync"/> is called.
    /// </remarks>
    /// <typeparam name="TEntity">Entity type.</typeparam>
    public class Repository<TEntity> : IRepository<TEntity> where TEntity : class
    {
        /// <summary>The database context this repository operates on.</summary>
        protected readonly AppDbContext _context;

        /// <summary>The entity set for <typeparamref name="TEntity"/>, taken from <see cref="_context"/>.</summary>
        protected readonly DbSet<TEntity> _dbSet;

        /// <summary>
        /// Creates a repository over <paramref name="context"/> and caches its set for <typeparamref name="TEntity"/>.
        /// </summary>
        /// <param name="context">The EF Core context to use.</param>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is <c>null</c>.</exception>
        public Repository(AppDbContext context)
        {
            _context = context ?? throw new ArgumentNullException(nameof(context));
            _dbSet = _context.Set<TEntity>();
        }

        #region Query methods

        /// <summary>Finds an entity by its primary key value via <c>DbSet.FindAsync</c>.</summary>
        /// <param name="id">Primary key value.</param>
        /// <param name="cancellationToken">Token to cancel the lookup.</param>
        /// <returns>The entity, or <c>null</c> when no entity has that key.</returns>
        public virtual async Task<TEntity?> GetByIdAsync(object id, CancellationToken cancellationToken = default)
        {
            return await _dbSet.FindAsync(new[] { id }, cancellationToken);
        }

        /// <summary>Loads every entity in the set into a list.</summary>
        /// <param name="cancellationToken">Token to cancel the query.</param>
        /// <returns>All entities of type <typeparamref name="TEntity"/>.</returns>
        public virtual async Task<IEnumerable<TEntity>> GetAllAsync(CancellationToken cancellationToken = default)
        {
            return await _dbSet.ToListAsync(cancellationToken);
        }

        /// <summary>Loads every entity matching <paramref name="predicate"/> into a list.</summary>
        /// <param name="predicate">Filter translated into the database query.</param>
        /// <param name="cancellationToken">Token to cancel the query.</param>
        /// <returns>The matching entities (possibly empty).</returns>
        public virtual async Task<IEnumerable<TEntity>> FindAsync(
            Expression<Func<TEntity, bool>> predicate,
            CancellationToken cancellationToken = default)
        {
            return await _dbSet.Where(predicate).ToListAsync(cancellationToken);
        }

        /// <summary>Returns the first entity matching <paramref name="predicate"/>.</summary>
        /// <param name="predicate">Filter translated into the database query.</param>
        /// <param name="cancellationToken">Token to cancel the query.</param>
        /// <returns>The first match, or <c>null</c> when nothing matches.</returns>
        public virtual async Task<TEntity?> FirstOrDefaultAsync(
            Expression<Func<TEntity, bool>> predicate,
            CancellationToken cancellationToken = default)
        {
            return await _dbSet.FirstOrDefaultAsync(predicate, cancellationToken);
        }

        /// <summary>
        /// Exposes the entity set as an <see cref="IQueryable{T}"/> so callers can compose their own queries.
        /// </summary>
        /// <returns>An unexecuted queryable over the set.</returns>
        public virtual IQueryable<TEntity> GetQueryable()
        {
            return _dbSet.AsQueryable();
        }

        /// <summary>Checks whether any entity matches <paramref name="predicate"/>.</summary>
        /// <param name="predicate">Filter translated into the database query.</param>
        /// <param name="cancellationToken">Token to cancel the query.</param>
        /// <returns><c>true</c> if at least one entity matches.</returns>
        public virtual async Task<bool> AnyAsync(
            Expression<Func<TEntity, bool>> predicate,
            CancellationToken cancellationToken = default)
        {
            return await _dbSet.AnyAsync(predicate, cancellationToken);
        }

        /// <summary>Counts entities, optionally restricted to those matching <paramref name="predicate"/>.</summary>
        /// <param name="predicate">Optional filter; when <c>null</c> every entity is counted.</param>
        /// <param name="cancellationToken">Token to cancel the query.</param>
        /// <returns>The number of (matching) entities.</returns>
        public virtual async Task<int> CountAsync(
            Expression<Func<TEntity, bool>>? predicate = null,
            CancellationToken cancellationToken = default)
        {
            if (predicate is null)
            {
                return await _dbSet.CountAsync(cancellationToken);
            }

            return await _dbSet.CountAsync(predicate, cancellationToken);
        }

        #endregion

        #region Modification methods

        /// <summary>Begins tracking <paramref name="entity"/> as added.</summary>
        /// <param name="entity">Entity to insert on the next save.</param>
        /// <param name="cancellationToken">Token forwarded to <c>DbSet.AddAsync</c>.</param>
        public virtual async Task AddAsync(TEntity entity, CancellationToken cancellationToken = default)
        {
            await _dbSet.AddAsync(entity, cancellationToken);
        }

        /// <summary>Begins tracking all <paramref name="entities"/> as added.</summary>
        /// <param name="entities">Entities to insert on the next save.</param>
        /// <param name="cancellationToken">Token forwarded to <c>DbSet.AddRangeAsync</c>.</param>
        public virtual async Task AddRangeAsync(IEnumerable<TEntity> entities, CancellationToken cancellationToken = default)
        {
            await _dbSet.AddRangeAsync(entities, cancellationToken);
        }

        /// <summary>Marks <paramref name="entity"/> as modified in the change tracker.</summary>
        /// <param name="entity">Entity to update on the next save.</param>
        /// <param name="cancellationToken">Unused; the operation is synchronous. Kept for interface symmetry.</param>
        /// <returns>An already-completed task.</returns>
        public virtual Task UpdateAsync(TEntity entity, CancellationToken cancellationToken = default)
        {
            // DbSet.Update is synchronous; returning a completed task avoids a needless
            // async state machine (and compiler warning CS1998).
            _dbSet.Update(entity);
            return Task.CompletedTask;
        }

        /// <summary>Marks all <paramref name="entities"/> as modified in the change tracker.</summary>
        /// <param name="entities">Entities to update on the next save.</param>
        /// <param name="cancellationToken">Unused; the operation is synchronous. Kept for interface symmetry.</param>
        /// <returns>An already-completed task.</returns>
        public virtual Task UpdateRangeAsync(IEnumerable<TEntity> entities, CancellationToken cancellationToken = default)
        {
            _dbSet.UpdateRange(entities);
            return Task.CompletedTask;
        }

        /// <summary>Marks <paramref name="entity"/> as deleted in the change tracker.</summary>
        /// <param name="entity">Entity to delete on the next save.</param>
        /// <param name="cancellationToken">Unused; the operation is synchronous. Kept for interface symmetry.</param>
        /// <returns>An already-completed task.</returns>
        public virtual Task DeleteAsync(TEntity entity, CancellationToken cancellationToken = default)
        {
            _dbSet.Remove(entity);
            return Task.CompletedTask;
        }

        /// <summary>Looks up an entity by primary key and, if found, marks it as deleted.</summary>
        /// <param name="id">Primary key value.</param>
        /// <param name="cancellationToken">Token to cancel the lookup.</param>
        /// <remarks>Does nothing when no entity has the given key.</remarks>
        public virtual async Task DeleteByIdAsync(object id, CancellationToken cancellationToken = default)
        {
            var entity = await GetByIdAsync(id, cancellationToken);
            if (entity is not null)
            {
                await DeleteAsync(entity, cancellationToken);
            }
        }

        /// <summary>Marks all <paramref name="entities"/> as deleted in the change tracker.</summary>
        /// <param name="entities">Entities to delete on the next save.</param>
        /// <param name="cancellationToken">Unused; the operation is synchronous. Kept for interface symmetry.</param>
        /// <returns>An already-completed task.</returns>
        public virtual Task DeleteRangeAsync(IEnumerable<TEntity> entities, CancellationToken cancellationToken = default)
        {
            _dbSet.RemoveRange(entities);
            return Task.CompletedTask;
        }

        /// <summary>Loads every entity matching <paramref name="predicate"/> and marks them all as deleted.</summary>
        /// <param name="predicate">Filter selecting the entities to delete.</param>
        /// <param name="cancellationToken">Token to cancel the query.</param>
        /// <remarks>
        /// The matching rows are materialized in memory before removal. For very large sets,
        /// a set-based delete may be preferable. NOTE(review): confirm expected data volumes.
        /// </remarks>
        public virtual async Task DeleteAsync(
            Expression<Func<TEntity, bool>> predicate,
            CancellationToken cancellationToken = default)
        {
            var entities = await FindAsync(predicate, cancellationToken);
            await DeleteRangeAsync(entities, cancellationToken);
        }

        #endregion

        #region Unit of work

        /// <summary>Persists all changes tracked by the underlying context.</summary>
        /// <param name="cancellationToken">Token to cancel the save.</param>
        /// <returns>The number of state entries written, as reported by <c>DbContext.SaveChangesAsync</c>.</returns>
        public virtual async Task<int> SaveChangesAsync(CancellationToken cancellationToken = default)
        {
            return await _context.SaveChangesAsync(cancellationToken);
        }

        #endregion
    }
}