diff --git a/GraphDiff/GraphDiff.Tests/GraphDiff.Tests.csproj b/GraphDiff/GraphDiff.Tests/GraphDiff.Tests.csproj index 02231ed..7a12243 100644 --- a/GraphDiff/GraphDiff.Tests/GraphDiff.Tests.csproj +++ b/GraphDiff/GraphDiff.Tests/GraphDiff.Tests.csproj @@ -72,6 +72,7 @@ + diff --git a/GraphDiff/GraphDiff.Tests/Tests/UpdateCollectionBehaviours.cs b/GraphDiff/GraphDiff.Tests/Tests/UpdateCollectionBehaviours.cs new file mode 100644 index 0000000..940e3ec --- /dev/null +++ b/GraphDiff/GraphDiff.Tests/Tests/UpdateCollectionBehaviours.cs @@ -0,0 +1,55 @@ +using System.Data.Entity; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using RefactorThis.GraphDiff.Tests.Models; + +namespace RefactorThis.GraphDiff.Tests.Tests +{ + [TestClass] + public class UpdateCollectionBehaviours : TestBase + { + [TestMethod] + public void ShouldAddSingleEntities() + { + var nodes = Enumerable.Range(1, 100) + .Select(i => new TestNode { Title = "Node" + i }) + .ToArray(); + + using (var context = new TestDbContext()) + { + var savedNodes = context.UpdateGraphs(nodes); + context.SaveChanges(); + + foreach (var node in savedNodes) + Assert.IsNotNull(context.Nodes.SingleOrDefault(p => p.Id == node.Id)); + } + } + + [TestMethod] + public void ShouldUpdateSingleEntities_Detached() + { + var nodes = Enumerable.Range(1, 100) + .Select(i => new TestNode { Title = "Node" + i }) + .ToArray(); + + using (var context = new TestDbContext()) + { + foreach (var node in nodes) + context.Nodes.Add(node); + context.SaveChanges(); + } // Simulate detach + + foreach (var node in nodes) + node.Title += "x"; + + using (var context = new TestDbContext()) + { + context.UpdateGraphs(nodes); + context.SaveChanges(); + + foreach (var node in nodes) + Assert.IsTrue(context.Nodes.Single(p => p.Id == node.Id).Title.EndsWith("x")); + } + } + } +} \ No newline at end of file diff --git a/GraphDiff/GraphDiff/DbContextExtensions.cs b/GraphDiff/GraphDiff/DbContextExtensions.cs index 4b10302..595eed3 100644 --- a/GraphDiff/GraphDiff/DbContextExtensions.cs +++ b/GraphDiff/GraphDiff/DbContextExtensions.cs @@ -5,7 +5,9 @@ */ using System; +using System.Collections.Generic; using System.Data.Entity; +using System.Linq; using System.Linq.Expressions; using RefactorThis.GraphDiff.Internal; using RefactorThis.GraphDiff.Internal.Caching; @@ -56,6 +58,19 @@ public static T UpdateGraph(this DbContext context, T entity, UpdateParams up return UpdateGraph(context, entity, null, null, updateParams); } + /// + /// Merges a graph of entities with the data store. + /// + /// The type of the root entity + /// The database context to attach / detach. + /// The root entities. + /// Update configuration overrides + /// The attached entity graphs + public static IEnumerable UpdateGraphs(this DbContext context, IEnumerable entities, UpdateParams updateParams = null) where T : class + { + return UpdateGraphs(context, entities, null, null, updateParams); + } + /// /// Load an aggregate type from the database (including all related entities) /// @@ -76,16 +91,25 @@ public static T LoadAggregate(this DbContext context, Expression(this DbContext context, T entity, Expression, object>> mapping, string mappingScheme, UpdateParams updateParams) where T : class { if (entity == null) throw new ArgumentNullException("entity"); + return UpdateGraphs(context, new[] { entity }, mapping, mappingScheme, updateParams).First(); + } + + // other methods are convenience wrappers around this. + private static IEnumerable UpdateGraphs(this DbContext context, IEnumerable entities, Expression, object>> mapping, + string mappingScheme, UpdateParams updateParams) where T : class + { + if (entities == null) + throw new ArgumentNullException("entities"); + var entityManager = new EntityManager(context); var queryLoader = new QueryLoader(context, entityManager); var register = new AggregateRegister(new CacheProvider()); @@ -94,7 +118,7 @@ private static T UpdateGraph(this DbContext context, T entity, Expression(context, queryLoader, entityManager, root); var queryMode = updateParams != null ? updateParams.QueryMode : QueryMode.SingleQuery; - return differ.Merge(entity, queryMode); + return differ.Merge(entities, queryMode); } private static GraphNode GetRootNode(Expression, object>> mapping, string mappingScheme, AggregateRegister register) where T : class @@ -118,4 +142,4 @@ private static GraphNode GetRootNode(Expression, return root; } } -} +} \ No newline at end of file diff --git a/GraphDiff/GraphDiff/GraphDiffConfiguration.cs b/GraphDiff/GraphDiff/GraphDiffConfiguration.cs index 9aaf2b0..63cf61a 100644 --- a/GraphDiff/GraphDiff/GraphDiffConfiguration.cs +++ b/GraphDiff/GraphDiff/GraphDiffConfiguration.cs @@ -15,6 +15,14 @@ public static class GraphDiffConfiguration /// If an entity is attached as an associated entity it will be automatically reloaded from the database /// to ensure the EF local cache has the latest state. /// - public static bool ReloadAssociatedEntitiesWhenAttached { get; set; } + public static bool ReloadAssociatedEntitiesWhenAttached { get; set; } + + /// + /// If an entity has integer primary keys (int, uint, long, ulong) with + /// empty values, it is considered to be new and not persisted. + /// In this case the loading of a persisted version of this entity can + /// be skipped to increase the performance of inserts. + /// + public static bool SkipLoadingOfNewEntities { get; set; } } -} +} \ No newline at end of file diff --git a/GraphDiff/GraphDiff/Internal/GraphDiffer.cs b/GraphDiff/GraphDiff/Internal/GraphDiffer.cs index e63ccfd..e169d6d 100644 --- a/GraphDiff/GraphDiff/Internal/GraphDiffer.cs +++ b/GraphDiff/GraphDiff/Internal/GraphDiffer.cs @@ -1,12 +1,14 @@ using System; +using System.Collections.Generic; using System.Data.Entity; +using System.Linq; using RefactorThis.GraphDiff.Internal.Graph; namespace RefactorThis.GraphDiff.Internal { internal interface IGraphDiffer where T : class { - T Merge(T updating, QueryMode queryMode = QueryMode.SingleQuery); + IEnumerable Merge(IEnumerable updatingItems, QueryMode queryMode = QueryMode.SingleQuery); } /// GraphDiff main entry point. @@ -26,7 +28,7 @@ public GraphDiffer(DbContext dbContext, IQueryLoader queryLoader, IEntityManager _entityManager = entityManager; } - public T Merge(T updating, QueryMode queryMode = QueryMode.SingleQuery) + public IEnumerable Merge(IEnumerable updatingItems, QueryMode queryMode = QueryMode.SingleQuery) { // todo query mode bool isAutoDetectEnabled = _dbContext.Configuration.AutoDetectChangesEnabled; @@ -37,30 +39,47 @@ public T Merge(T updating, QueryMode queryMode = QueryMode.SingleQuery) // Get our entity with all includes needed, or add a new entity var includeStrings = _root.GetIncludeStrings(_entityManager); - T persisted = _queryLoader.LoadEntity(updating, includeStrings, queryMode); - if (persisted == null) + var entityManager = new EntityManager(_dbContext); + var changeTracker = new ChangeTracker(_dbContext, entityManager); + var persistedItems = _queryLoader + .LoadEntities(updatingItems, includeStrings, queryMode) + .ToArray(); + var index = 0; + + foreach (var updating in updatingItems) { - // we are always working with 2 graphs, simply add a 'persisted' one if none exists, - // this ensures that only the changes we make within the bounds of the mapping are attempted. - persisted = (T)_dbContext.Set(updating.GetType()).Create(); + // try to get persisted entity + if (index > persistedItems.Length - 1) + { + throw new InvalidOperationException( + String.Format("Could not load all persisted entities of type '{0}'.", + typeof(T).FullName)); + } - _dbContext.Set().Add(persisted); - } + if (persistedItems[index] == null) + { + // we are always working with 2 graphs, simply add a 'persisted' one if none exists, + // this ensures that only the changes we make within the bounds of the mapping are attempted. + persistedItems[index] = (T)_dbContext.Set(updating.GetType()).Create(); - if (_dbContext.Entry(updating).State != EntityState.Detached) - { - throw new InvalidOperationException( + _dbContext.Set().Add(persistedItems[index]); + } + + if (_dbContext.Entry(updating).State != EntityState.Detached) + { + throw new InvalidOperationException( String.Format("Entity of type '{0}' is already in an attached state. GraphDiff supports detached entities only at this time. Please try AsNoTracking() or detach your entites before calling the UpdateGraph method.", - typeof (T).FullName)); - } + typeof (T).FullName)); + } - // Perform recursive update - var entityManager = new EntityManager(_dbContext); - var changeTracker = new ChangeTracker(_dbContext, entityManager); - _root.Update(changeTracker, entityManager, persisted, updating); + // Perform recursive update + _root.Update(changeTracker, entityManager, persistedItems[index], updating); + + index++; + } - return persisted; + return persistedItems; } finally { diff --git a/GraphDiff/GraphDiff/Internal/QueryLoader.cs b/GraphDiff/GraphDiff/Internal/QueryLoader.cs index 7d37395..3105ecc 100644 --- a/GraphDiff/GraphDiff/Internal/QueryLoader.cs +++ b/GraphDiff/GraphDiff/Internal/QueryLoader.cs @@ -9,9 +9,9 @@ namespace RefactorThis.GraphDiff.Internal { /// Db load queries internal interface IQueryLoader - { - T LoadEntity(T entity, IEnumerable includeStrings, QueryMode queryMode) where T : class; - T LoadEntity(Expression> keyPredicate, IEnumerable includeStrings, QueryMode queryMode) where T : class; + { + IEnumerable LoadEntities(IEnumerable entities, IEnumerable includeStrings, QueryMode queryMode) where T : class; + IEnumerable LoadEntities(Expression> keyPredicate, IEnumerable includeStrings, QueryMode queryMode) where T : class; } internal class QueryLoader : IQueryLoader @@ -25,29 +25,65 @@ public QueryLoader(DbContext context, IEntityManager entityManager) _context = context; } - public T LoadEntity(T entity, IEnumerable includeStrings, QueryMode queryMode) where T : class + public IEnumerable LoadEntities(IEnumerable entities, IEnumerable includeStrings, QueryMode queryMode) where T : class { - if (entity == null) + if (entities == null) { - throw new ArgumentNullException("entity"); + throw new ArgumentNullException("entities"); } - var keyPredicate = CreateKeyPredicateExpression(entity); - return LoadEntity(keyPredicate, includeStrings, queryMode); + var keyProperties = _entityManager.GetPrimaryKeyFieldsFor(typeof(T)).ToArray(); + var keyValues = entities.Select(e => keyProperties.Select(x => x.GetValue(e, null)).ToArray()).ToArray(); + var keyPredicate = CreateKeyPredicateExpression(entities, keyProperties, keyValues); + var entityCount = keyValues.Length; + + // skip loading of entities with empty integral key propeties (new entitites) + if (keyPredicate == null) + return new T[entityCount]; + + // load presisted entities + var loadedEntities = LoadEntities(keyPredicate, includeStrings, queryMode); + + // skip sort for single entities + if (entityCount == 1) + return new[] { loadedEntities.FirstOrDefault() }; + + // restore order of loaded entities + var orderedEntities = new List(entityCount); + foreach (var entity in entities) + { + var entityKeyValues = keyValues[orderedEntities.Count]; + orderedEntities.Add(loadedEntities.FirstOrDefault(x => + { + // find matching item by key values + for (var i = 0; i < entityKeyValues.Length; i++) + if (!Equals(entityKeyValues[i], keyProperties[i].GetValue(x, null))) + return false; + return true; + })); + } + + // validate count + if (orderedEntities.Count != entityCount) + throw new InvalidOperationException( + String.Format("Could not load all {0} persisted items of type '{1}'.", + entityCount, typeof(T).FullName)); + + return orderedEntities; } - public T LoadEntity(Expression> keyPredicate, IEnumerable includeStrings, QueryMode queryMode) where T : class + public IEnumerable LoadEntities(Expression> keyPredicate, IEnumerable includeStrings, QueryMode queryMode) where T : class { if (queryMode == QueryMode.SingleQuery) { var query = _context.Set().AsQueryable(); query = includeStrings.Aggregate(query, (current, include) => current.Include(include)); - return query.SingleOrDefault(keyPredicate); + return query.Where(keyPredicate).ToArray(); } if (queryMode == QueryMode.MultipleQuery) { - // This is experimental - needs some testing. + // This is experimental - needs some testing. foreach (var include in includeStrings) { var query = _context.Set().AsQueryable(); @@ -55,30 +91,84 @@ public T LoadEntity(Expression> keyPredicate, IEnumerable().Local.AsQueryable().SingleOrDefault(keyPredicate); + return _context.Set().Local.AsQueryable().Where(keyPredicate).ToArray(); } throw new ArgumentOutOfRangeException("queryMode", "Unknown QueryMode"); } - private Expression> CreateKeyPredicateExpression(T entity) + private Expression> CreateKeyPredicateExpression(IEnumerable entities, IList keyProperties, IEnumerable> keyValues) { // get key properties of T - var keyProperties = _entityManager.GetPrimaryKeyFieldsFor(typeof(T)).ToList(); + ParameterExpression parameter = Expression.Parameter(typeof(T)); + Expression resultExpression = null; + var keyValuesEnumerator = keyValues.GetEnumerator(); + + foreach (var entity in entities) + { + if (!keyValuesEnumerator.MoveNext()) + throw new InvalidOperationException( + String.Format("Number of key values does not match number of entities with type '{0}'.", + typeof(T).FullName)); - ParameterExpression parameter = Expression.Parameter(typeof(T)); - Expression expression = CreateEqualsExpression(entity, keyProperties[0], parameter); - for (int i = 1; i < keyProperties.Count; i++) - { - expression = Expression.And(expression, CreateEqualsExpression(entity, keyProperties[i], parameter)); - } + // prevent key predicate with empty values + if (GraphDiffConfiguration.SkipLoadingOfNewEntities && + AllIntegralKeysEmpty(keyProperties, keyValuesEnumerator.Current)) + continue; + + // create predicate for entity + var itemExpression = CreateEqualsExpression(keyValuesEnumerator.Current[0], keyProperties[0], parameter); + for (int i = 1; i < keyProperties.Count; i++) + itemExpression = Expression.AndAlso(itemExpression, + CreateEqualsExpression(keyValuesEnumerator.Current[i], keyProperties[i], parameter)); + + // compose all entity predicates + resultExpression = resultExpression != null + ? Expression.OrElse(resultExpression, itemExpression) + : itemExpression; + } - return Expression.Lambda>(expression, parameter); - } - - private static Expression CreateEqualsExpression(object entity, PropertyInfo keyProperty, Expression parameter) - { - return Expression.Equal(Expression.Property(parameter, keyProperty), Expression.Constant(keyProperty.GetValue(entity, null), keyProperty.PropertyType)); - } + return resultExpression != null + ? Expression.Lambda>(resultExpression, parameter) + : null; + } + + private static Expression CreateEqualsExpression(object keyValue, PropertyInfo keyProperty, Expression parameter) + { + return Expression.Equal(Expression.Property(parameter, keyProperty), Expression.Constant(keyValue, keyProperty.PropertyType)); + } + + private static bool AllIntegralKeysEmpty(IList properties, IList values) + { + for (var i = 0; i < properties.Count; i++) + { + // detect empty numeric key properties (new entity) + if (properties[i].PropertyType == typeof(int)) + { + if ((int)values[i] == 0) + continue; + } + else if (properties[i].PropertyType == typeof(uint)) + { + if ((uint)values[i] == 0) + continue; + } + else if (properties[i].PropertyType == typeof(long)) + { + if ((long)values[i] == 0) + continue; + } + else if (properties[i].PropertyType == typeof(ulong)) + { + if ((ulong)values[i] == 0) + continue; + } + + // skip this optimization for other types + return false; + } + + return true; + } } -} +} \ No newline at end of file