1. 首先EF的Repository需要抽象的行为提到接口中。
public interface IXXXContext : IDisposable
{
IXXXContext NewInstance();
// db sets
DbSet<AAABBB> aaa { get; set; }
...
// common
Database Database { get; }
DbContextConfiguration Configuration { get; }
int SaveChanges();
Task<int> SaveChangesAsync();
// store pros
...
IStorePro1 StorePro1 { get; }
...
}
然后就可以使用DataContext和TestDataContext实现这个接口。其中TestDataContext是在UT中使用的,DataContext是自动生成的。
TestDataContext还需要以下几个类进行模拟。
public class TestDbSet<TEntity> : DbSet<TEntity>, IQueryable, IEnumerable<TEntity>, IDbAsyncEnumerable<TEntity>
where TEntity : class
{
ObservableCollection<TEntity> _data;
IQueryable _query;
public TestDbSet()
{
_data = new ObservableCollection<TEntity>();
_query = _data.AsQueryable();
}
public override TEntity Add(TEntity item)
{
_data.Add(item);
return item;
}
public override TEntity Remove(TEntity item)
{
_data.Remove(item);
return item;
}
public override TEntity Attach(TEntity item)
{
_data.Add(item);
return item;
}
public override TEntity Create()
{
return Activator.CreateInstance<TEntity>();
}
public override TDerivedEntity Create<TDerivedEntity>()
{
return Activator.CreateInstance<TDerivedEntity>();
}
public override ObservableCollection<TEntity> Local
{
get { return _data; }
}
Type IQueryable.ElementType
{
get { return _query.ElementType; }
}
Expression IQueryable.Expression
{
get { return _query.Expression; }
}
IQueryProvider IQueryable.Provider
{
get { return new TestDbAsyncQueryProvider<TEntity>(_query.Provider); }
}
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
{
return _data.GetEnumerator();
}
IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator()
{
return _data.GetEnumerator();
}
IDbAsyncEnumerator<TEntity> IDbAsyncEnumerable<TEntity>.GetAsyncEnumerator()
{
return new TestDbAsyncEnumerator<TEntity>(_data.GetEnumerator());
}
}
internal class TestDbAsyncQueryProvider<TEntity> : IDbAsyncQueryProvider
{
private readonly IQueryProvider _inner;
internal TestDbAsyncQueryProvider(IQueryProvider inner)
{
_inner = inner;
}
public IQueryable CreateQuery(Expression expression)
{
return new TestDbAsyncEnumerable<TEntity>(expression);
}
public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
{
return new TestDbAsyncEnumerable<TElement>(expression);
}
public object Execute(Expression expression)
{
return _inner.Execute(expression);
}
public TResult Execute<TResult>(Expression expression)
{
return _inner.Execute<TResult>(expression);
}
public Task<object> ExecuteAsync(Expression expression, CancellationToken cancellationToken)
{
return Task.FromResult(Execute(expression));
}
public Task<TResult> ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken)
{
return Task.FromResult(Execute<TResult>(expression));
}
}
internal class TestDbAsyncEnumerable<T> : EnumerableQuery<T>, IDbAsyncEnumerable<T>, IQueryable<T>
{
public TestDbAsyncEnumerable(IEnumerable<T> enumerable)
: base(enumerable)
{ }
public TestDbAsyncEnumerable(Expression expression)
: base(expression)
{ }
public IDbAsyncEnumerator<T> GetAsyncEnumerator()
{
return new TestDbAsyncEnumerator<T>(this.AsEnumerable().GetEnumerator());
}
IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
{
return GetAsyncEnumerator();
}
IQueryProvider IQueryable.Provider
{
get { return new TestDbAsyncQueryProvider<T>(this); }
}
}
internal class TestDbAsyncEnumerator<T> : IDbAsyncEnumerator<T>
{
private readonly IEnumerator<T> _inner;
public TestDbAsyncEnumerator(IEnumerator<T> inner)
{
_inner = inner;
}
public void Dispose()
{
_inner.Dispose();
}
public Task<bool> MoveNextAsync(CancellationToken cancellationToken)
{
return Task.FromResult(_inner.MoveNext());
}
public T Current
{
get { return _inner.Current; }
}
object IDbAsyncEnumerator.Current
{
get { return Current; }
}
}
使用示例:
[TestMethod]
public void TestMethod1()
{
var mockSet = new Mock<DbSet<BLACKLISTED_TICKET>>();
var mockContext = new Mock<TicketDataContextTest>();
mockContext.Setup(m => m.BLACKLISTED_TICKET).Returns(new TestDbSet<BLACKLISTED_TICKET>());
var context = mockContext.Object;
context.BLACKLISTED_TICKET.Add(new BLACKLISTED_TICKET()
{
TicketNumber = "aaa",
CreatedDateTime = DateTime.Now,
Id = 1,
ModifiedDateTime = DateTime.Now,
STATUS = "1"
});
Assert.IsTrue(context.BLACKLISTED_TICKET.First().Id == 1);
}
如果使用了存储过程,需要额外定义存储过程的接口。
例如:
IStorePro {
...
}
StorePro : IStorePro{
...
}
StoreProFake: IStorePro{
}
然后IDataContext负责返回存储过程的实例
IDataContext{
...
IStorePro GetStorePro();
...
}