diff --git a/GraphDiff/GraphDiff.Tests/GraphDiff.Tests.csproj b/GraphDiff/GraphDiff.Tests/GraphDiff.Tests.csproj
index 02231ed..7a12243 100644
--- a/GraphDiff/GraphDiff.Tests/GraphDiff.Tests.csproj
+++ b/GraphDiff/GraphDiff.Tests/GraphDiff.Tests.csproj
@@ -72,6 +72,7 @@
+
diff --git a/GraphDiff/GraphDiff.Tests/Tests/UpdateCollectionBehaviours.cs b/GraphDiff/GraphDiff.Tests/Tests/UpdateCollectionBehaviours.cs
new file mode 100644
index 0000000..940e3ec
--- /dev/null
+++ b/GraphDiff/GraphDiff.Tests/Tests/UpdateCollectionBehaviours.cs
@@ -0,0 +1,55 @@
+using System.Data.Entity;
+using System.Linq;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using RefactorThis.GraphDiff.Tests.Models;
+
+namespace RefactorThis.GraphDiff.Tests.Tests
+{
+ [TestClass]
+ public class UpdateCollectionBehaviours : TestBase
+ {
+ [TestMethod]
+ public void ShouldAddSingleEntities()
+ {
+ var nodes = Enumerable.Range(1, 100)
+ .Select(i => new TestNode { Title = "Node" + i })
+ .ToArray();
+
+ using (var context = new TestDbContext())
+ {
+ var savedNodes = context.UpdateGraphs(nodes);
+ context.SaveChanges();
+
+ foreach (var node in savedNodes)
+ Assert.IsNotNull(context.Nodes.SingleOrDefault(p => p.Id == node.Id));
+ }
+ }
+
+ [TestMethod]
+ public void ShouldUpdateSingleEntities_Detached()
+ {
+ var nodes = Enumerable.Range(1, 100)
+ .Select(i => new TestNode { Title = "Node" + i })
+ .ToArray();
+
+ using (var context = new TestDbContext())
+ {
+ foreach (var node in nodes)
+ context.Nodes.Add(node);
+ context.SaveChanges();
+ } // Simulate detach
+
+ foreach (var node in nodes)
+ node.Title += "x";
+
+ using (var context = new TestDbContext())
+ {
+ context.UpdateGraphs(nodes);
+ context.SaveChanges();
+
+ foreach (var node in nodes)
+ Assert.IsTrue(context.Nodes.Single(p => p.Id == node.Id).Title.EndsWith("x"));
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/GraphDiff/GraphDiff/DbContextExtensions.cs b/GraphDiff/GraphDiff/DbContextExtensions.cs
index 4b10302..595eed3 100644
--- a/GraphDiff/GraphDiff/DbContextExtensions.cs
+++ b/GraphDiff/GraphDiff/DbContextExtensions.cs
@@ -5,7 +5,9 @@
*/
using System;
+using System.Collections.Generic;
using System.Data.Entity;
+using System.Linq;
using System.Linq.Expressions;
using RefactorThis.GraphDiff.Internal;
using RefactorThis.GraphDiff.Internal.Caching;
@@ -56,6 +58,19 @@ public static T UpdateGraph(this DbContext context, T entity, UpdateParams up
return UpdateGraph(context, entity, null, null, updateParams);
}
+ ///
+ /// Merges a graph of entities with the data store.
+ ///
+ /// The type of the root entity
+ /// The database context to attach / detach.
+ /// The root entities.
+ /// Update configuration overrides
+ /// The attached entity graphs
+ public static IEnumerable UpdateGraphs(this DbContext context, IEnumerable entities, UpdateParams updateParams = null) where T : class
+ {
+ return UpdateGraphs(context, entities, null, null, updateParams);
+ }
+
///
/// Load an aggregate type from the database (including all related entities)
///
@@ -76,16 +91,25 @@ public static T LoadAggregate(this DbContext context, Expression(this DbContext context, T entity, Expression, object>> mapping,
string mappingScheme, UpdateParams updateParams) where T : class
{
if (entity == null)
throw new ArgumentNullException("entity");
+ return UpdateGraphs(context, new[] { entity }, mapping, mappingScheme, updateParams).First();
+ }
+
+ // other methods are convenience wrappers around this.
+ private static IEnumerable UpdateGraphs(this DbContext context, IEnumerable entities, Expression, object>> mapping,
+ string mappingScheme, UpdateParams updateParams) where T : class
+ {
+ if (entities == null)
+ throw new ArgumentNullException("entities");
+
var entityManager = new EntityManager(context);
var queryLoader = new QueryLoader(context, entityManager);
var register = new AggregateRegister(new CacheProvider());
@@ -94,7 +118,7 @@ private static T UpdateGraph(this DbContext context, T entity, Expression(context, queryLoader, entityManager, root);
var queryMode = updateParams != null ? updateParams.QueryMode : QueryMode.SingleQuery;
- return differ.Merge(entity, queryMode);
+ return differ.Merge(entities, queryMode);
}
private static GraphNode GetRootNode(Expression, object>> mapping, string mappingScheme, AggregateRegister register) where T : class
@@ -118,4 +142,4 @@ private static GraphNode GetRootNode(Expression,
return root;
}
}
-}
+}
\ No newline at end of file
diff --git a/GraphDiff/GraphDiff/GraphDiffConfiguration.cs b/GraphDiff/GraphDiff/GraphDiffConfiguration.cs
index 9aaf2b0..63cf61a 100644
--- a/GraphDiff/GraphDiff/GraphDiffConfiguration.cs
+++ b/GraphDiff/GraphDiff/GraphDiffConfiguration.cs
@@ -15,6 +15,14 @@ public static class GraphDiffConfiguration
/// If an entity is attached as an associated entity it will be automatically reloaded from the database
/// to ensure the EF local cache has the latest state.
///
- public static bool ReloadAssociatedEntitiesWhenAttached { get; set; }
+ public static bool ReloadAssociatedEntitiesWhenAttached { get; set; }
+
+ ///
+ /// If an entity has integer primary keys (int, uint, long, ulong) with
+ /// empty values, it is considered to be new and not persisted.
+ /// In this case the loading of a persisted version of this entity can
+ /// be skipped to increase the performance of inserts.
+ ///
+ public static bool SkipLoadingOfNewEntities { get; set; }
}
-}
+}
\ No newline at end of file
diff --git a/GraphDiff/GraphDiff/Internal/GraphDiffer.cs b/GraphDiff/GraphDiff/Internal/GraphDiffer.cs
index e63ccfd..e169d6d 100644
--- a/GraphDiff/GraphDiff/Internal/GraphDiffer.cs
+++ b/GraphDiff/GraphDiff/Internal/GraphDiffer.cs
@@ -1,12 +1,14 @@
using System;
+using System.Collections.Generic;
using System.Data.Entity;
+using System.Linq;
using RefactorThis.GraphDiff.Internal.Graph;
namespace RefactorThis.GraphDiff.Internal
{
internal interface IGraphDiffer where T : class
{
- T Merge(T updating, QueryMode queryMode = QueryMode.SingleQuery);
+ IEnumerable Merge(IEnumerable updatingItems, QueryMode queryMode = QueryMode.SingleQuery);
}
/// GraphDiff main entry point.
@@ -26,7 +28,7 @@ public GraphDiffer(DbContext dbContext, IQueryLoader queryLoader, IEntityManager
_entityManager = entityManager;
}
- public T Merge(T updating, QueryMode queryMode = QueryMode.SingleQuery)
+ public IEnumerable Merge(IEnumerable updatingItems, QueryMode queryMode = QueryMode.SingleQuery)
{
// todo query mode
bool isAutoDetectEnabled = _dbContext.Configuration.AutoDetectChangesEnabled;
@@ -37,30 +39,47 @@ public T Merge(T updating, QueryMode queryMode = QueryMode.SingleQuery)
// Get our entity with all includes needed, or add a new entity
var includeStrings = _root.GetIncludeStrings(_entityManager);
- T persisted = _queryLoader.LoadEntity(updating, includeStrings, queryMode);
- if (persisted == null)
+ var entityManager = new EntityManager(_dbContext);
+ var changeTracker = new ChangeTracker(_dbContext, entityManager);
+ var persistedItems = _queryLoader
+ .LoadEntities(updatingItems, includeStrings, queryMode)
+ .ToArray();
+ var index = 0;
+
+ foreach (var updating in updatingItems)
{
- // we are always working with 2 graphs, simply add a 'persisted' one if none exists,
- // this ensures that only the changes we make within the bounds of the mapping are attempted.
- persisted = (T)_dbContext.Set(updating.GetType()).Create();
+ // try to get persisted entity
+ if (index > persistedItems.Length - 1)
+ {
+ throw new InvalidOperationException(
+ String.Format("Could not load all persisted entities of type '{0}'.",
+ typeof(T).FullName));
+ }
- _dbContext.Set().Add(persisted);
- }
+ if (persistedItems[index] == null)
+ {
+ // we are always working with 2 graphs, simply add a 'persisted' one if none exists,
+ // this ensures that only the changes we make within the bounds of the mapping are attempted.
+ persistedItems[index] = (T)_dbContext.Set(updating.GetType()).Create();
- if (_dbContext.Entry(updating).State != EntityState.Detached)
- {
- throw new InvalidOperationException(
+ _dbContext.Set().Add(persistedItems[index]);
+ }
+
+ if (_dbContext.Entry(updating).State != EntityState.Detached)
+ {
+ throw new InvalidOperationException(
String.Format("Entity of type '{0}' is already in an attached state. GraphDiff supports detached entities only at this time. Please try AsNoTracking() or detach your entites before calling the UpdateGraph method.",
- typeof (T).FullName));
- }
+ typeof (T).FullName));
+ }
- // Perform recursive update
- var entityManager = new EntityManager(_dbContext);
- var changeTracker = new ChangeTracker(_dbContext, entityManager);
- _root.Update(changeTracker, entityManager, persisted, updating);
+ // Perform recursive update
+ _root.Update(changeTracker, entityManager, persistedItems[index], updating);
+
+ index++;
+ }
- return persisted;
+ return persistedItems;
}
finally
{
diff --git a/GraphDiff/GraphDiff/Internal/QueryLoader.cs b/GraphDiff/GraphDiff/Internal/QueryLoader.cs
index 7d37395..3105ecc 100644
--- a/GraphDiff/GraphDiff/Internal/QueryLoader.cs
+++ b/GraphDiff/GraphDiff/Internal/QueryLoader.cs
@@ -9,9 +9,9 @@ namespace RefactorThis.GraphDiff.Internal
{
/// Db load queries
internal interface IQueryLoader
- {
- T LoadEntity(T entity, IEnumerable includeStrings, QueryMode queryMode) where T : class;
- T LoadEntity(Expression> keyPredicate, IEnumerable includeStrings, QueryMode queryMode) where T : class;
+ {
+ IEnumerable LoadEntities(IEnumerable entities, IEnumerable includeStrings, QueryMode queryMode) where T : class;
+ IEnumerable LoadEntities(Expression> keyPredicate, IEnumerable includeStrings, QueryMode queryMode) where T : class;
}
internal class QueryLoader : IQueryLoader
@@ -25,29 +25,65 @@ public QueryLoader(DbContext context, IEntityManager entityManager)
_context = context;
}
- public T LoadEntity(T entity, IEnumerable includeStrings, QueryMode queryMode) where T : class
+ public IEnumerable LoadEntities(IEnumerable entities, IEnumerable includeStrings, QueryMode queryMode) where T : class
{
- if (entity == null)
+ if (entities == null)
{
- throw new ArgumentNullException("entity");
+ throw new ArgumentNullException("entities");
}
- var keyPredicate = CreateKeyPredicateExpression(entity);
- return LoadEntity(keyPredicate, includeStrings, queryMode);
+ var keyProperties = _entityManager.GetPrimaryKeyFieldsFor(typeof(T)).ToArray();
+ var keyValues = entities.Select(e => keyProperties.Select(x => x.GetValue(e, null)).ToArray()).ToArray();
+ var keyPredicate = CreateKeyPredicateExpression(entities, keyProperties, keyValues);
+ var entityCount = keyValues.Length;
+
+ // skip loading of entities with empty integral key propeties (new entitites)
+ if (keyPredicate == null)
+ return new T[entityCount];
+
+ // load presisted entities
+ var loadedEntities = LoadEntities(keyPredicate, includeStrings, queryMode);
+
+ // skip sort for single entities
+ if (entityCount == 1)
+ return new[] { loadedEntities.FirstOrDefault() };
+
+ // restore order of loaded entities
+ var orderedEntities = new List(entityCount);
+ foreach (var entity in entities)
+ {
+ var entityKeyValues = keyValues[orderedEntities.Count];
+ orderedEntities.Add(loadedEntities.FirstOrDefault(x =>
+ {
+ // find matching item by key values
+ for (var i = 0; i < entityKeyValues.Length; i++)
+ if (!Equals(entityKeyValues[i], keyProperties[i].GetValue(x, null)))
+ return false;
+ return true;
+ }));
+ }
+
+ // validate count
+ if (orderedEntities.Count != entityCount)
+ throw new InvalidOperationException(
+ String.Format("Could not load all {0} persisted items of type '{1}'.",
+ entityCount, typeof(T).FullName));
+
+ return orderedEntities;
}
- public T LoadEntity(Expression> keyPredicate, IEnumerable includeStrings, QueryMode queryMode) where T : class
+ public IEnumerable LoadEntities(Expression> keyPredicate, IEnumerable includeStrings, QueryMode queryMode) where T : class
{
if (queryMode == QueryMode.SingleQuery)
{
var query = _context.Set().AsQueryable();
query = includeStrings.Aggregate(query, (current, include) => current.Include(include));
- return query.SingleOrDefault(keyPredicate);
+ return query.Where(keyPredicate).ToArray();
}
if (queryMode == QueryMode.MultipleQuery)
{
- // This is experimental - needs some testing.
+ // This is experimental - needs some testing.
foreach (var include in includeStrings)
{
var query = _context.Set().AsQueryable();
@@ -55,30 +91,84 @@ public T LoadEntity(Expression> keyPredicate, IEnumerable().Local.AsQueryable().SingleOrDefault(keyPredicate);
+ return _context.Set().Local.AsQueryable().Where(keyPredicate).ToArray();
}
throw new ArgumentOutOfRangeException("queryMode", "Unknown QueryMode");
}
- private Expression> CreateKeyPredicateExpression(T entity)
+ private Expression> CreateKeyPredicateExpression(IEnumerable entities, IList keyProperties, IEnumerable> keyValues)
{
// get key properties of T
- var keyProperties = _entityManager.GetPrimaryKeyFieldsFor(typeof(T)).ToList();
+ ParameterExpression parameter = Expression.Parameter(typeof(T));
+ Expression resultExpression = null;
+ var keyValuesEnumerator = keyValues.GetEnumerator();
+
+ foreach (var entity in entities)
+ {
+ if (!keyValuesEnumerator.MoveNext())
+ throw new InvalidOperationException(
+ String.Format("Number of key values does not match number of entities with type '{0}'.",
+ typeof(T).FullName));
- ParameterExpression parameter = Expression.Parameter(typeof(T));
- Expression expression = CreateEqualsExpression(entity, keyProperties[0], parameter);
- for (int i = 1; i < keyProperties.Count; i++)
- {
- expression = Expression.And(expression, CreateEqualsExpression(entity, keyProperties[i], parameter));
- }
+ // prevent key predicate with empty values
+ if (GraphDiffConfiguration.SkipLoadingOfNewEntities &&
+ AllIntegralKeysEmpty(keyProperties, keyValuesEnumerator.Current))
+ continue;
+
+ // create predicate for entity
+ var itemExpression = CreateEqualsExpression(keyValuesEnumerator.Current[0], keyProperties[0], parameter);
+ for (int i = 1; i < keyProperties.Count; i++)
+ itemExpression = Expression.AndAlso(itemExpression,
+ CreateEqualsExpression(keyValuesEnumerator.Current[i], keyProperties[i], parameter));
+
+ // compose all entity predicates
+ resultExpression = resultExpression != null
+ ? Expression.OrElse(resultExpression, itemExpression)
+ : itemExpression;
+ }
- return Expression.Lambda>(expression, parameter);
- }
-
- private static Expression CreateEqualsExpression(object entity, PropertyInfo keyProperty, Expression parameter)
- {
- return Expression.Equal(Expression.Property(parameter, keyProperty), Expression.Constant(keyProperty.GetValue(entity, null), keyProperty.PropertyType));
- }
+ return resultExpression != null
+ ? Expression.Lambda>(resultExpression, parameter)
+ : null;
+ }
+
+ private static Expression CreateEqualsExpression(object keyValue, PropertyInfo keyProperty, Expression parameter)
+ {
+ return Expression.Equal(Expression.Property(parameter, keyProperty), Expression.Constant(keyValue, keyProperty.PropertyType));
+ }
+
+ private static bool AllIntegralKeysEmpty(IList properties, IList