diff --git a/src/NHibernate.Test/DriverTest/BulkInsertTests.cs b/src/NHibernate.Test/DriverTest/BulkInsertTests.cs new file mode 100644 index 00000000000..5b45daa70b3 --- /dev/null +++ b/src/NHibernate.Test/DriverTest/BulkInsertTests.cs @@ -0,0 +1,60 @@ +using System.Linq; +using NHibernate.Cfg; +using NHibernate.DomainModel.Northwind.Entities; +using NHibernate.Linq; +using NHibernate.Test.Linq; +using NUnit.Framework; + +namespace NHibernate.Test.DriverTest +{ + [TestFixture] + public class BulkInsertTests : LinqTestCase + { + protected override void Configure(Configuration configuration) + { + configuration.SetProperty(Environment.Hbm2ddlAuto, SchemaAutoAction.Create.ToString()); + + base.Configure(configuration); + } + + [Test] + public void CanBulkInsertEntitiesWithComponents() + { + //NH-3675 + using (var statelessSession = session.SessionFactory.OpenStatelessSession()) + using (statelessSession.BeginTransaction()) + { + var customers = new Customer[] { new Customer { Address = new Address("street", "city", "region", "postalCode", "country", "phoneNumber", "fax"), CompanyName = "Company", ContactName = "Contact", ContactTitle = "Title", CustomerId = "12345" } }; + + statelessSession.CreateQuery("delete from Customer").ExecuteUpdate(); + + statelessSession.BulkInsert(customers); + + var count = statelessSession.Query().Count(); + + Assert.AreEqual(customers.Count(), count); + } + } + + [Test] + public void CanBulkInsertEntitiesWithComponentsAndAssociations() + { + //NH-3675 + using (var statelessSession = session.SessionFactory.OpenStatelessSession()) + using (statelessSession.BeginTransaction()) + { + var superior = new Employee { Address = new Address("street", "city", "region", "zip", "country", "phone", "fax"), BirthDate = System.DateTime.Now, EmployeeId = 1, Extension = "1", FirstName = "Superior", LastName = "Last" }; + var employee = new Employee { Address = new Address("street", "city", "region", "zip", "country", "phone", "fax"), BirthDate = System.DateTime.Now, EmployeeId = 2, Extension = "2", FirstName = "Employee", LastName = "Last", Superior = superior }; + var employees = new Employee[] { superior, employee }; + + statelessSession.CreateQuery("delete from Employee").ExecuteUpdate(); + + statelessSession.BulkInsert(employees); + + var count = statelessSession.Query().Count(); + + Assert.AreEqual(employees.Count(), count); + } + } + } +} diff --git a/src/NHibernate/Cfg/Environment.cs b/src/NHibernate/Cfg/Environment.cs index 1368e26d3bb..7546c27069d 100644 --- a/src/NHibernate/Cfg/Environment.cs +++ b/src/NHibernate/Cfg/Environment.cs @@ -65,6 +65,10 @@ public static string Version } } + public const String BulkProviderClass = "adonet.bulk_provider_class"; + public const String BulkProviderTimeout = "adonet.bulk_provider_timeout"; + public const String BulkProviderBatchSize = "adonet.bulk_provider_batch_size"; + public const string ConnectionProvider = "connection.provider"; public const string ConnectionDriver = "connection.driver_class"; public const string ConnectionString = "connection.connection_string"; diff --git a/src/NHibernate/Driver/BulkProvider.cs b/src/NHibernate/Driver/BulkProvider.cs new file mode 100644 index 00000000000..9ae5469efcd --- /dev/null +++ b/src/NHibernate/Driver/BulkProvider.cs @@ -0,0 +1,63 @@ +using System; +using System.Collections.Generic; +using NHibernate.Engine; +using NHibernate.Id; +using NHibernate.Persister.Entity; +using Environment = NHibernate.Cfg.Environment; + +namespace NHibernate.Driver +{ + public abstract class BulkProvider : IDisposable + { + protected BulkProvider() + { + } + + ~BulkProvider() + { + this.Dispose(false); + } + + public Int32 BatchSize { get; set; } + + public Int32 Timeout { get; set; } + + public abstract void Insert(ISessionImplementor session, IEnumerable entities) where T : class; + + public virtual void Initialize(IDictionary properties) + { + var timeout = string.Empty; + var batchSize = string.Empty; + + if (properties.TryGetValue(Environment.BulkProviderTimeout, out timeout)) + { + this.Timeout = Convert.ToInt32(timeout); + } + + if (properties.TryGetValue(Environment.BulkProviderBatchSize, out batchSize)) + { + this.BatchSize = Convert.ToInt32(batchSize); + } + } + + protected virtual void FillIdentifier(ISessionImplementor session, IEntityPersister persister, Object entity) + { + if (!(persister.IdentifierGenerator is Assigned) && !(persister.IdentifierGenerator is ForeignGenerator)) + { + var id = persister.IdentifierGenerator.Generate(session, entity); + + persister.SetIdentifier(entity, id, session.EntityMode); + } + } + + protected virtual void Dispose(Boolean disposing) + { + } + + public void Dispose() + { + this.Dispose(true); + GC.SuppressFinalize(this); + } + } +} diff --git a/src/NHibernate/Driver/DefaultBulkProvider.cs b/src/NHibernate/Driver/DefaultBulkProvider.cs new file mode 100644 index 00000000000..bc03324027e --- /dev/null +++ b/src/NHibernate/Driver/DefaultBulkProvider.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using NHibernate.Engine; + +namespace NHibernate.Driver +{ + sealed class DefaultBulkProvider : BulkProvider + { + public override void Insert(ISessionImplementor session, IEnumerable entities) + { + var statelessSession = session as IStatelessSession; + + if (statelessSession == null) + { + throw new InvalidOperationException("Insert can only be called with stateless sessions."); + } + + foreach (var entity in entities) + { + statelessSession.Insert(entity); + } + } + } +} diff --git a/src/NHibernate/Driver/DriverBase.cs b/src/NHibernate/Driver/DriverBase.cs index 938a83e3dff..deb676a2062 100644 --- a/src/NHibernate/Driver/DriverBase.cs +++ b/src/NHibernate/Driver/DriverBase.cs @@ -21,6 +21,11 @@ public abstract class DriverBase : IDriver, ISqlParameterFormatter private int commandTimeout; private bool prepareSql; + public virtual BulkProvider GetBulkProvider() + { + return new DefaultBulkProvider(); + } + public virtual void Configure(IDictionary settings) { // Command timeout diff --git a/src/NHibernate/Driver/IDriver.cs b/src/NHibernate/Driver/IDriver.cs index 2916037d1f5..76b9548e8d0 100644 --- a/src/NHibernate/Driver/IDriver.cs +++ b/src/NHibernate/Driver/IDriver.cs @@ -32,6 +32,11 @@ namespace NHibernate.Driver /// public interface IDriver { + /// + /// Returns a bulk provider for the current driver. + /// + BulkProvider GetBulkProvider(); + /// /// Configure the driver using . /// diff --git a/src/NHibernate/Driver/OracleDataClientBulkProvider.cs b/src/NHibernate/Driver/OracleDataClientBulkProvider.cs new file mode 100644 index 00000000000..d9d1bc678dd --- /dev/null +++ b/src/NHibernate/Driver/OracleDataClientBulkProvider.cs @@ -0,0 +1,55 @@ +using System; +using System.Collections.Generic; +using System.Data; +using System.Linq; +using System.Reflection; +using NHibernate.Engine; + +namespace NHibernate.Driver +{ + public class OracleDataClientBulkProvider : TableBasedBulkProvider + { + public const String BulkProviderOptions = "adonet.bulk_provider_options"; + + private static readonly System.Type bulkCopyOptionsType = System.Type.GetType("Oracle.DataAccess.Client.OracleBulkCopyOptions, Oracle.DataAccess"); + private static readonly System.Type bulkCopyType = System.Type.GetType("Oracle.DataAccess.Client.OracleBulkCopy, Oracle.DataAccess"); + private static readonly PropertyInfo batchSizeProperty = bulkCopyType.GetProperty("BatchSize"); + private static readonly PropertyInfo bulkCopyTimeoutProperty = bulkCopyType.GetProperty("BulkCopyTimeout"); + private static readonly PropertyInfo destinationTableNameProperty = bulkCopyType.GetProperty("DestinationTableName"); + private static readonly MethodInfo writeToServerMethod = bulkCopyType.GetMethod("WriteToServer", new System.Type[] { typeof(DataTable) }); + + public Int32 Options { get; set; } + + public Int32 NotifyAfter { get; set; } + + public override void Initialize(IDictionary properties) + { + base.Initialize(properties); + + var bulkProviderOptions = String.Empty; + + if (properties.TryGetValue(BulkProviderOptions, out bulkProviderOptions)) + { + this.Options = Convert.ToInt32(bulkProviderOptions); + } + } + + public override void Insert(ISessionImplementor session, IEnumerable entities) + { + if (entities.Any() == true) + { + foreach (var table in this.GetTables(session, entities)) + { + using (var copy = Activator.CreateInstance(bulkCopyType, session.Connection, Enum.ToObject(bulkCopyOptionsType, this.Options)) as IDisposable) + { + batchSizeProperty.SetValue(copy, this.BatchSize, null); + bulkCopyTimeoutProperty.SetValue(copy, this.Timeout, null); + destinationTableNameProperty.SetValue(copy, table.TableName, null); + + writeToServerMethod.Invoke(copy, new Object[] { table }); + } + } + } + } + } +} diff --git a/src/NHibernate/Driver/OracleDataClientDriver.cs b/src/NHibernate/Driver/OracleDataClientDriver.cs index 465f9da81a5..e1a2f53ca02 100644 --- a/src/NHibernate/Driver/OracleDataClientDriver.cs +++ b/src/NHibernate/Driver/OracleDataClientDriver.cs @@ -52,6 +52,11 @@ public OracleDataClientDriver() oracleDbTypeXmlType = Enum.Parse(oracleDbTypeEnum, "XmlType"); } + public override BulkProvider GetBulkProvider() + { + return new OracleDataClientBulkProvider(); + } + /// public override bool UseNamedPrefixInSql { diff --git a/src/NHibernate/Driver/SqlBulkProvider.cs b/src/NHibernate/Driver/SqlBulkProvider.cs new file mode 100644 index 00000000000..5bb7c22388f --- /dev/null +++ b/src/NHibernate/Driver/SqlBulkProvider.cs @@ -0,0 +1,46 @@ +using System; +using System.Collections.Generic; +using System.Data.SqlClient; +using System.Linq; +using NHibernate.Engine; +using NHibernate.Transaction; + +namespace NHibernate.Driver +{ + public class SqlBulkProvider : TableBasedBulkProvider + { + public const String BulkProviderOptions = "adonet.bulk_provider_options"; + + public SqlBulkCopyOptions Options { get; set; } + + + public override void Initialize(IDictionary properties) + { + base.Initialize(properties); + + var bulkProviderOptions = String.Empty; + + if (properties.TryGetValue(BulkProviderOptions, out bulkProviderOptions)) + { + this.Options = (SqlBulkCopyOptions)Enum.Parse(typeof(SqlBulkCopyOptions), bulkProviderOptions, true); + } + } + + public override void Insert(ISessionImplementor session, IEnumerable entities) + { + if (entities.Any() == true) + { + var con = session.Connection as SqlConnection; + var tx = (session.ConnectionManager.Transaction as AdoTransaction).GetNativeTransaction() as SqlTransaction; + + foreach (var table in this.GetTables(session, entities)) + { + using (var copy = new SqlBulkCopy(con, this.Options, tx) { BatchSize = this.BatchSize, BulkCopyTimeout = this.Timeout, DestinationTableName = table.TableName }) + { + copy.WriteToServer(table); + } + } + } + } + } +} diff --git a/src/NHibernate/Driver/SqlClientDriver.cs b/src/NHibernate/Driver/SqlClientDriver.cs index fadec4ea5ab..c977f8bc8d8 100644 --- a/src/NHibernate/Driver/SqlClientDriver.cs +++ b/src/NHibernate/Driver/SqlClientDriver.cs @@ -26,6 +26,11 @@ public class SqlClientDriver : DriverBase, IEmbeddedBatcherFactoryProvider public const byte MaxDateTime2 = 8; public const byte MaxDateTimeOffset = 10; + public override BulkProvider GetBulkProvider() + { + return new SqlBulkProvider(); + } + /// /// Creates an uninitialized object for /// the SqlClientDriver. diff --git a/src/NHibernate/Driver/TableBasedBulkProvider.cs b/src/NHibernate/Driver/TableBasedBulkProvider.cs new file mode 100644 index 00000000000..8e156e03879 --- /dev/null +++ b/src/NHibernate/Driver/TableBasedBulkProvider.cs @@ -0,0 +1,197 @@ +using System; +using System.Collections; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Data; +using System.Linq; +using System.Reflection; +using NHibernate.Engine; +using NHibernate.Id; +using NHibernate.Persister.Entity; +using NHibernate.Type; + +namespace NHibernate.Driver +{ + public abstract class TableBasedBulkProvider : BulkProvider + { + protected virtual IEnumerable GetTables(ISessionImplementor session, IEnumerable entities) + { + if (session.EntityMode != EntityMode.Poco) + { + throw new InvalidOperationException(String.Format("Entity mode {0} is not supported for bulk inserts.", session.EntityMode)); + } + + var tables = new Dictionary(); + + foreach (var entityTypes in entities.GroupBy(x => x.GetType())) + { + var entityType = entityTypes.Key; + var persister = session.GetEntityPersister(entityType.FullName, null) as AbstractEntityPersister; + var table = new DataTable(persister.TableName); + tables[table.TableName] = table; + + var map = new Hashtable(); + + if (persister.IdentifierGenerator is IPostInsertIdentifierGenerator) + { + throw new ArgumentException("Post insert identifier generators cannot be used for bulk inserts."); + } + + if (String.IsNullOrWhiteSpace(persister.IdentifierPropertyName) == true) + { + throw new ArgumentException("Entities without an identity property cannot be used for bulk inserts."); + } + + for (var c = 0; c < persister.IdentifierColumnNames.Length; ++c) + { + var columnName = persister.IdentifierColumnNames[c]; + + if (persister.IdentifierType is ComponentType) + { + table.Columns.Add(columnName, (persister.IdentifierType as ComponentType).ReturnedClass.GetProperty((persister.IdentifierType as ComponentType).PropertyNames[c]).PropertyType).ExtendedProperties["PropertyName"] = String.Concat(persister.IdentifierPropertyName, ".", (persister.IdentifierType as ComponentType).PropertyNames[c]); + } + else if (persister.EntityMetamodel.Properties[c].Type is OneToOneType) + { + table.Columns.Add(columnName, (persister.EntityMetamodel.Properties[c].Type as OneToOneType).GetIdentifierOrUniqueKeyPropertyName(session.Factory).GetType()).ExtendedProperties["PropertyName"] = String.Concat(persister.EntityMetamodel.Properties[c].Name, ".", (persister.EntityMetamodel.Properties[c].Type as OneToOneType).GetIdentifierOrUniqueKeyPropertyName(session.Factory)); + } + else + { + table.Columns.Add(columnName, persister.IdentifierType.ReturnedClass).ExtendedProperties["PropertyName"] = persister.IdentifierPropertyName; + } + } + + for (var i = 0; i < persister.EntityMetamodel.Properties.Length; ++i) + { + if (persister.EntityMetamodel.PropertyInsertability[i] == false) + { + continue; + } + + if (persister.EntityMetamodel.Properties[i].Type.IsCollectionType == true) + { + continue; + } + + var columnNames = persister.GetPropertyColumnNames(persister.EntityMetamodel.Properties[i].Name); + + for (var c = 0; c < columnNames.Length; ++c) + { + var columnName = columnNames[c]; + + if (persister.EntityMetamodel.Properties[i].Type is ComponentType) + { + table.Columns.Add(columnName, (persister.EntityMetamodel.Properties[i].Type as ComponentType).ReturnedClass.GetProperty((persister.EntityMetamodel.Properties[i].Type as ComponentType).PropertyNames[c]).PropertyType).ExtendedProperties["PropertyName"] = String.Concat(persister.EntityMetamodel.Properties[i].Name, ".", (persister.EntityMetamodel.Properties[i].Type as ComponentType).PropertyNames[c]); + } + else if (persister.EntityMetamodel.Properties[i].Type is OneToOneType) + { + table.Columns.Add(columnName, (persister.EntityMetamodel.Properties[i].Type as OneToOneType).GetIdentifierOrUniqueKeyPropertyName(session.Factory).GetType()).ExtendedProperties["PropertyName"] = String.Concat(persister.EntityMetamodel.Properties[i].Name, ".", (persister.EntityMetamodel.Properties[i].Type as OneToOneType).GetIdentifierOrUniqueKeyPropertyName(session.Factory)); + } + else if (persister.EntityMetamodel.Properties[i].Type is ManyToOneType) + { + table.Columns.Add(columnName, (persister.EntityMetamodel.Properties[i].Type as ManyToOneType).GetIdentifierOrUniqueKeyPropertyName(session.Factory).GetType()).ExtendedProperties["PropertyName"] = String.Concat(persister.EntityMetamodel.Properties[i].Name, ".", (persister.EntityMetamodel.Properties[i].Type as ManyToOneType).GetIdentifierOrUniqueKeyPropertyName(session.Factory)); + } + else + { + table.Columns.Add(columnName, persister.EntityMetamodel.Properties[i].Type.ReturnedClass).ExtendedProperties["PropertyName"] = persister.EntityMetamodel.Properties[i].Name; + } + } + } + + table.BeginLoadData(); + + foreach (var entity in entityTypes) + { + var row = table.NewRow(); + + for (var c = 0; c < table.Columns.Count; ++c) + { + var value = Eval(entity, table.Columns[c].ExtendedProperties["PropertyName"].ToString()) ?? DBNull.Value; + row[c] = value; + } + + table.Rows.Add(row); + } + + table.EndLoadData(); + } + + return (tables.Values); + } + + private object Eval(object instance, string path) + { + if (instance == null) + { + return null; + } + + var context = instance; + object value = null; + var parts = path.Split('.'); + + for (var i = 0; i < parts.Length; ++i) + { + value = GetPropertyOrFieldValue(context, parts[i]); + context = value; + } + + return value; + } + + private MemberInfo GetPropertyOrField(object instance, string memberName) + { + if (instance == null) + { + return null; + } + + var property = instance.GetType().GetProperty(memberName, BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public); + + if (property != null) + { + return property; + } + + var field = FindField(instance.GetType(), memberName); + + return field; + } + + private object GetPropertyOrFieldValue(object instance, string memberName) + { + if (instance == null) + { + return null; + } + + var member = GetPropertyOrField(instance, memberName); + + if (member != null) + { + return (member is FieldInfo) ? (member as FieldInfo).GetValue(instance) : (member as PropertyInfo).GetValue(instance, null); + } + + throw new InvalidOperationException(string.Format("Member named {0} does not exist in type {1}.", memberName, instance.GetType().FullName)); + } + + private FieldInfo FindField(System.Type type, string fieldName) + { + var currentType = type; + var field = null as FieldInfo; + + while (currentType != typeof (object)) + { + field = currentType.GetField(fieldName, BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public); + + if (field != null) + { + return field; + } + + currentType = currentType.BaseType; + } + + return field; + } + } +} diff --git a/src/NHibernate/SessionExtensions.cs b/src/NHibernate/SessionExtensions.cs new file mode 100644 index 00000000000..b8708dac9c4 --- /dev/null +++ b/src/NHibernate/SessionExtensions.cs @@ -0,0 +1,38 @@ +using System; +using System.Collections.Generic; +using NHibernate.Engine; +using NHibernate.Exceptions; +using Environment = NHibernate.Cfg.Environment; + +namespace NHibernate +{ + public static class SessionExtensions + { + public static void BulkInsert(this IStatelessSession session, IEnumerable entities) where T : class + { + BulkInsert(session.GetSessionImplementation(), entities); + } + + private static void BulkInsert(ISessionImplementor session, IEnumerable entities) where T : class + { + using (var provider = session.Factory.ConnectionProvider.Driver.GetBulkProvider()) + { + if (provider == null) + { + throw new InvalidOperationException("Current driver does not support bulk inserts."); + } + + provider.Initialize(Environment.Properties); + + try + { + provider.Insert(session, entities); + } + catch (Exception e) + { + throw ADOExceptionHelper.Convert(session.Factory.SQLExceptionConverter, e, "could not execute bulk insert."); + } + } + } + } +} diff --git a/src/NHibernate/Transaction/AdoTransaction.cs b/src/NHibernate/Transaction/AdoTransaction.cs index 3b3f79ac5d6..cbdf06877e2 100644 --- a/src/NHibernate/Transaction/AdoTransaction.cs +++ b/src/NHibernate/Transaction/AdoTransaction.cs @@ -454,5 +454,10 @@ private void NotifyLocalSynchsAfterTransactionCompletion(bool success) } } } + + internal IDbTransaction GetNativeTransaction() + { + return trans; + } } }