diff --git a/refinery/tests/mysql.rs b/refinery/tests/mysql.rs
index a9c487bd..a76f2f29 100644
--- a/refinery/tests/mysql.rs
+++ b/refinery/tests/mysql.rs
@@ -371,6 +371,7 @@ mod mysql {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .unwrap();
 
@@ -483,6 +484,7 @@ mod mysql {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .unwrap_err();
 
@@ -519,6 +521,7 @@ mod mysql {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .unwrap_err();
 
@@ -568,6 +571,7 @@ mod mysql {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .unwrap_err();
         match err.kind() {
diff --git a/refinery/tests/mysql_async.rs b/refinery/tests/mysql_async.rs
index a3892ae4..557cc676 100644
--- a/refinery/tests/mysql_async.rs
+++ b/refinery/tests/mysql_async.rs
@@ -374,6 +374,7 @@ mod mysql_async {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .await
             .unwrap();
@@ -494,6 +495,7 @@ mod mysql_async {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .await
             .unwrap_err();
@@ -539,6 +541,7 @@ mod mysql_async {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .await
             .unwrap_err();
@@ -591,6 +594,7 @@ mod mysql_async {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .await
             .unwrap_err();
diff --git a/refinery/tests/postgres.rs b/refinery/tests/postgres.rs
index ec10a2c8..96e69992 100644
--- a/refinery/tests/postgres.rs
+++ b/refinery/tests/postgres.rs
@@ -351,6 +351,7 @@ mod postgres {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .unwrap();
 
@@ -455,6 +456,7 @@ mod postgres {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .unwrap_err();
 
@@ -488,6 +490,7 @@ mod postgres {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .unwrap_err();
 
@@ -534,6 +537,7 @@ mod postgres {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .unwrap_err();
         match err.kind() {
diff --git a/refinery/tests/rusqlite.rs b/refinery/tests/rusqlite.rs
index 258af424..84acb8d2 100644
--- a/refinery/tests/rusqlite.rs
+++ b/refinery/tests/rusqlite.rs
@@ -449,6 +449,7 @@ mod rusqlite {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .unwrap();
 
@@ -583,6 +584,7 @@ mod rusqlite {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .unwrap_err();
 
@@ -614,6 +616,7 @@ mod rusqlite {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .unwrap_err();
 
@@ -658,6 +661,7 @@ mod rusqlite {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .unwrap_err();
         match err.kind() {
diff --git a/refinery/tests/tiberius.rs b/refinery/tests/tiberius.rs
index f6346063..9abee01c 100644
--- a/refinery/tests/tiberius.rs
+++ b/refinery/tests/tiberius.rs
@@ -138,6 +138,7 @@ mod tiberius {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .await
             .unwrap_err();
@@ -189,6 +190,7 @@ mod tiberius {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .await
             .unwrap_err();
@@ -253,6 +255,7 @@ mod tiberius {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .await
             .unwrap_err();
@@ -523,6 +526,7 @@ mod tiberius {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .await
             .unwrap();
diff --git a/refinery/tests/tokio_postgres.rs b/refinery/tests/tokio_postgres.rs
index 2f7f7fc6..f58a4b02 100644
--- a/refinery/tests/tokio_postgres.rs
+++ b/refinery/tests/tokio_postgres.rs
@@ -495,6 +495,7 @@ mod tokio_postgres {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .await
             .unwrap();
@@ -631,6 +632,7 @@ mod tokio_postgres {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .await
             .unwrap_err();
@@ -677,6 +679,7 @@ mod tokio_postgres {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .await
             .unwrap_err();
@@ -735,6 +738,7 @@ mod tokio_postgres {
                 false,
                 Target::Latest,
                 DEFAULT_TABLE_NAME,
+                None,
             )
             .await
             .unwrap_err();
@@ -757,7 +761,8 @@ mod tokio_postgres {
             .set_db_name("postgres")
             .set_db_user("postgres")
             .set_db_host("localhost")
-            .set_db_port("5432");
+            .set_db_port("5432")
+            .set_db_schema("public");
 
         let migrations = get_migrations();
         let runner = Runner::new(&migrations)
@@ -801,7 +806,8 @@ mod tokio_postgres {
             .set_db_name("postgres")
             .set_db_user("postgres")
             .set_db_host("localhost")
-            .set_db_port("5432");
+            .set_db_port("5432")
+            .set_db_schema("public");
 
         let migrations = get_migrations();
         let runner = Runner::new(&migrations)
@@ -842,7 +848,8 @@ mod tokio_postgres {
             .set_db_name("postgres")
             .set_db_user("postgres")
             .set_db_host("localhost")
-            .set_db_port("5432");
+            .set_db_port("5432")
+            .set_db_schema("public");
 
         let migrations = get_migrations();
         let runner = Runner::new(&migrations)
diff --git a/refinery_cli/src/cli.rs b/refinery_cli/src/cli.rs
index 03d13593..e2b28b17 100644
--- a/refinery_cli/src/cli.rs
+++ b/refinery_cli/src/cli.rs
@@ -44,6 +44,10 @@ pub struct MigrateArgs {
     #[clap(long, default_value = "refinery_schema_history")]
     pub table_name: String,
 
+    /// Set explicit migration table schema
+    #[clap(long)]
+    pub table_schema: Option<String>,
+
     /// Should abort if divergent migrations are found
     #[clap(short)]
     pub divergent: bool,
diff --git a/refinery_cli/src/migrate.rs b/refinery_cli/src/migrate.rs
index b61fa876..9fa023ef 100644
--- a/refinery_cli/src/migrate.rs
+++ b/refinery_cli/src/migrate.rs
@@ -19,6 +19,7 @@ pub fn handle_migration_command(args: MigrateArgs) -> anyhow::Result<()> {
         args.env_var.as_deref(),
         &args.path,
         &args.table_name,
+        args.table_schema.as_deref(),
     )?;
     Ok(())
 }
@@ -34,6 +35,7 @@ fn run_migrations(
     env_var_opt: Option<&str>,
     path: &Path,
     table_name: &str,
+    table_schema: Option<&str>,
 ) -> anyhow::Result<()> {
     let migration_files_path = find_migration_files(path, MigrationType::Sql)?;
     let mut migrations = Vec::new();
@@ -79,6 +81,7 @@ fn run_migrations(
                 .set_abort_divergent(divergent)
                 .set_abort_missing(missing)
                 .set_migration_table_name(table_name)
+                .set_migration_table_schema(table_schema.or(config.db_schema()))
                 .run_async(&mut config)
                 .await
         })?;
@@ -96,6 +99,7 @@ fn run_migrations(
             .set_abort_missing(missing)
             .set_target(target)
             .set_migration_table_name(table_name)
+            .set_migration_table_schema(table_schema.or(config.db_schema()))
             .run(&mut config)?;
     } else {
         panic!("tried to migrate async from config for a {:?} database, but it's matching feature was not enabled!", _db_type);
diff --git a/refinery_cli/src/setup.rs b/refinery_cli/src/setup.rs
index e6e06f90..b212de41 100644
--- a/refinery_cli/src/setup.rs
+++ b/refinery_cli/src/setup.rs
@@ -89,5 +89,13 @@ fn get_config_from_input() -> Result<Config> {
     io::stdin().read_line(&mut db_name)?;
     config = config.set_db_name(db_name.trim());
 
+    print!("Enter optional schema name (empty to use the default schema): ");
+    io::stdout().flush()?;
+    let mut db_schema = String::new();
+    io::stdin().read_line(&mut db_schema)?;
+    db_schema = db_schema.trim().to_string();
+    if !db_schema.is_empty() {
+        config = config.set_db_schema(&db_schema);
+    }
     Ok(config)
 }
diff --git a/refinery_core/src/config.rs b/refinery_core/src/config.rs
index 7eba833f..f989dfd7 100644
--- a/refinery_core/src/config.rs
+++ b/refinery_core/src/config.rs
@@ -34,6 +34,7 @@ impl Config {
                 db_user: None,
                 db_pass: None,
                 db_name: None,
+                db_schema: None,
                 #[cfg(feature = "tiberius-config")]
                 trust_cert: false,
             },
@@ -139,6 +140,10 @@ impl Config {
         self.main.db_port.as_deref()
     }
 
+    pub fn db_schema(&self) -> Option<&str> {
+        self.main.db_schema.as_deref()
+    }
+
     pub fn set_db_user(self, db_user: &str) -> Config {
         Config {
             main: Main {
@@ -183,6 +188,15 @@ impl Config {
             },
         }
     }
+
+    pub fn set_db_schema(self, db_schema: &str) -> Config {
+        Config {
+            main: Main {
+                db_schema: Some(db_schema.into()),
+                ..self.main
+            },
+        }
+    }
 }
 
 impl TryFrom<Url> for Config {
@@ -238,6 +252,9 @@ impl TryFrom<Url> for Config {
                 db_user: Some(url.username().to_string()),
                 db_pass: url.password().map(|r| r.to_string()),
                 db_name: Some(url.path().trim_start_matches('/').to_string()),
+                db_schema: url
+                    .query_pairs()
+                    .find_map(|(name, value)| (name == "currentSchema").then(|| value.to_string())),
                 #[cfg(feature = "tiberius-config")]
                 trust_cert,
             },
@@ -270,6 +287,7 @@ struct Main {
     db_user: Option<String>,
     db_pass: Option<String>,
     db_name: Option<String>,
+    db_schema: Option<String>,
     #[cfg(feature = "tiberius-config")]
     #[serde(default)]
     trust_cert: bool,
@@ -421,7 +439,8 @@ mod tests {
                      db_port = \"5432\" \n
                      db_user = \"root\" \n
                      db_pass = \"1234\" \n
-                     db_name = \"refinery\"";
+                     db_name = \"refinery\" \n
+                     db_schema = \"public\"";
 
         let config: Config = toml::from_str(config).unwrap();
 
diff --git a/refinery_core/src/drivers/config.rs b/refinery_core/src/drivers/config.rs
index 639ff108..0200d163 100644
--- a/refinery_core/src/drivers/config.rs
+++ b/refinery_core/src/drivers/config.rs
@@ -204,6 +204,7 @@ impl crate::Migrate for Config {
         grouped: bool,
         target: Target,
         migration_table_name: &str,
+        migration_table_schema: Option<&str>,
     ) -> Result<Report, Error> {
         with_connection!(self, |mut conn| {
             crate::Migrate::migrate(
@@ -214,6 +215,7 @@
                 grouped,
                 target,
                 migration_table_name,
+                migration_table_schema,
             )
         })
     }
@@ -267,6 +269,7 @@ impl crate::AsyncMigrate for Config {
         grouped: bool,
         target: Target,
         migration_table_name: &str,
+        migration_table_schema: Option<&str>,
     ) -> Result<Report, Error> {
         with_connection_async!(self, move |mut conn| async move {
             crate::AsyncMigrate::migrate(
@@ -277,6 +280,7 @@
                 grouped,
                 target,
                 migration_table_name,
+                migration_table_schema,
             )
             .await
         })
diff --git a/refinery_core/src/runner.rs b/refinery_core/src/runner.rs
index af8e2ab5..3258ad79 100644
--- a/refinery_core/src/runner.rs
+++ b/refinery_core/src/runner.rs
@@ -226,6 +226,7 @@ pub struct Runner {
     migrations: Vec<Migration>,
     target: Target,
     migration_table_name: String,
+    migration_table_schema: Option<String>,
 }
 
 impl Runner {
@@ -238,6 +239,7 @@ impl Runner {
             abort_missing: true,
             migrations: migrations.to_vec(),
             migration_table_name: DEFAULT_MIGRATION_TABLE_NAME.into(),
+            migration_table_schema: None,
         }
     }
 
@@ -291,7 +293,8 @@ impl Runner {
     where
         C: Migrate,
     {
-        Migrate::get_last_applied_migration(conn, &self.migration_table_name)
+        let migration_table_name = self.get_migration_table_name();
+        Migrate::get_last_applied_migration(conn, &migration_table_name)
     }
 
     /// Queries the database asynchronously for the last applied migration, returns None if there aren't applied Migrations
@@ -302,7 +305,8 @@ impl Runner {
    where
        C: AsyncMigrate + Send,
    {
-        AsyncMigrate::get_last_applied_migration(conn, &self.migration_table_name).await
+        let migration_table_name = self.get_migration_table_name();
+        AsyncMigrate::get_last_applied_migration(conn, &migration_table_name).await
     }
 
     /// Queries the database for all previous applied migrations
@@ -310,7 +314,8 @@ impl Runner {
     where
         C: Migrate,
     {
-        Migrate::get_applied_migrations(conn, &self.migration_table_name)
+        let migration_table_name = self.get_migration_table_name();
+        Migrate::get_applied_migrations(conn, &migration_table_name)
     }
 
     /// Queries the database asynchronously for all previous applied migrations
@@ -321,7 +326,8 @@ impl Runner {
     where
         C: AsyncMigrate + Send,
     {
-        AsyncMigrate::get_applied_migrations(conn, &self.migration_table_name).await
+        let migration_table_name = self.get_migration_table_name();
+        AsyncMigrate::get_applied_migrations(conn, &migration_table_name).await
     }
 
     /// Set the table name to use for the migrations table. The default name is `refinery_schema_history`
@@ -345,6 +351,16 @@ impl Runner {
         self
     }
 
+    /// Set the explicit schema to use for the migrations table.
+    /// The default is `None`, which means the default schema is used.
+    pub fn set_migration_table_schema<S: AsRef<str>>(
+        &mut self,
+        migration_table_schema: Option<S>,
+    ) -> &mut Self {
+        self.migration_table_schema = migration_table_schema.map(|s| s.as_ref().to_string());
+        self
+    }
+
     /// Creates an iterator over pending migrations, applying each before returning
     /// the result from `next()`. If a migration fails, the iterator will return that
     /// result and further calls to `next()` will return `None`.
@@ -371,6 +387,7 @@ impl Runner {
             self.grouped,
             self.target,
             &self.migration_table_name,
+            self.migration_table_schema.as_deref(),
         )
     }
 
@@ -387,9 +404,18 @@ impl Runner {
             self.grouped,
             self.target,
             &self.migration_table_name,
+            self.migration_table_schema.as_deref(),
         )
         .await
     }
+
+    fn get_migration_table_name(&self) -> String {
+        if let Some(schema) = &self.migration_table_schema {
+            format!(r#""{schema}"."{}""#, self.migration_table_name)
+        } else {
+            self.migration_table_name.clone()
+        }
+    }
 }
 
 pub struct RunIterator<'a, C> {
@@ -412,6 +438,7 @@ where
                 runner.abort_divergent,
                 runner.abort_missing,
                 &runner.migration_table_name,
+                runner.migration_table_schema.as_deref(),
             )
             .unwrap(),
         ),
diff --git a/refinery_core/src/traits/async.rs b/refinery_core/src/traits/async.rs
index fc9e4f75..291dc467 100644
--- a/refinery_core/src/traits/async.rs
+++ b/refinery_core/src/traits/async.rs
@@ -1,7 +1,7 @@
 use crate::error::WrapMigrationError;
 use crate::traits::{
-    insert_migration_query, verify_migrations, ASSERT_MIGRATIONS_TABLE_QUERY,
-    GET_APPLIED_MIGRATIONS_QUERY, GET_LAST_APPLIED_MIGRATION_QUERY,
+    insert_migration_query, verify_migrations, ASSERT_MIGRATIONS_SCHEMA_QUERY,
+    ASSERT_MIGRATIONS_TABLE_QUERY, GET_APPLIED_MIGRATIONS_QUERY, GET_LAST_APPLIED_MIGRATION_QUERY,
 };
 use crate::{Error, Migration, Report, Target};
 
@@ -165,13 +165,27 @@ where
         grouped: bool,
         target: Target,
         migration_table_name: &str,
+        migration_table_schema: Option<&str>,
     ) -> Result<Report, Error> {
-        self.execute(&[&Self::assert_migrations_table_query(migration_table_name)])
+        let mut queries = Vec::with_capacity(1);
+        let migration_schema_query;
+        let migration_table_name = if let Some(schema) = migration_table_schema {
+            migration_schema_query =
+                ASSERT_MIGRATIONS_SCHEMA_QUERY.replace("%MIGRATION_TABLE_SCHEMA%", schema);
+            queries.push(migration_schema_query.as_str());
+            format!(r#""{schema}"."{migration_table_name}""#)
+        } else {
+            migration_table_name.to_string()
+        };
+
+        let migration_table_query = Self::assert_migrations_table_query(&migration_table_name);
+        queries.push(migration_table_query.as_str());
+        self.execute(&queries)
             .await
             .migration_err("error asserting migrations table", None)?;
 
         let applied_migrations = self
-            .get_applied_migrations(migration_table_name)
+            .get_applied_migrations(&migration_table_name)
             .await
             .migration_err("error getting current schema version", None)?;
 
@@ -187,9 +201,9 @@
         }
 
         if grouped || matches!(target, Target::Fake | Target::FakeVersion(_)) {
-            migrate_grouped(self, migrations, target, migration_table_name).await
+            migrate_grouped(self, migrations, target, &migration_table_name).await
         } else {
-            migrate(self, migrations, target, migration_table_name).await
+            migrate(self, migrations, target, &migration_table_name).await
         }
     }
 }
diff --git a/refinery_core/src/traits/mod.rs b/refinery_core/src/traits/mod.rs
index d3eef6d3..c27a6f8a 100644
--- a/refinery_core/src/traits/mod.rs
+++ b/refinery_core/src/traits/mod.rs
@@ -110,6 +110,9 @@ pub(crate) const ASSERT_MIGRATIONS_TABLE_QUERY: &str =
     applied_on VARCHAR(255),
     checksum VARCHAR(255));";
 
+pub(crate) const ASSERT_MIGRATIONS_SCHEMA_QUERY: &str =
+    r#"CREATE SCHEMA IF NOT EXISTS "%MIGRATION_TABLE_SCHEMA%";"#;
+
 pub(crate) const GET_APPLIED_MIGRATIONS_QUERY: &str =
     "SELECT version, name, applied_on, checksum \
     FROM %MIGRATION_TABLE_NAME% ORDER BY version ASC;";
diff --git a/refinery_core/src/traits/sync.rs b/refinery_core/src/traits/sync.rs
index 23cc7a90..1525f3e9 100644
--- a/refinery_core/src/traits/sync.rs
+++ b/refinery_core/src/traits/sync.rs
@@ -1,7 +1,7 @@
 use crate::error::WrapMigrationError;
 use crate::traits::{
-    insert_migration_query, verify_migrations, ASSERT_MIGRATIONS_TABLE_QUERY,
-    GET_APPLIED_MIGRATIONS_QUERY, GET_LAST_APPLIED_MIGRATION_QUERY,
+    insert_migration_query, verify_migrations, ASSERT_MIGRATIONS_SCHEMA_QUERY,
+    ASSERT_MIGRATIONS_TABLE_QUERY, GET_APPLIED_MIGRATIONS_QUERY, GET_LAST_APPLIED_MIGRATION_QUERY,
 };
 use crate::{Error, Migration, Report, Target};
 
@@ -102,10 +102,27 @@ where
         GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name)
     }
 
-    fn assert_migrations_table(&mut self, migration_table_name: &str) -> Result<usize, Error> {
+    fn assert_migrations_table(
+        &mut self,
+        migration_table_name: &str,
+        migration_table_schema: Option<&str>,
+    ) -> Result<usize, Error> {
+        let mut queries = Vec::with_capacity(1);
+        let assert_migrations_schema;
+        let migration_table_name = if let Some(schema) = migration_table_schema {
+            assert_migrations_schema =
+                ASSERT_MIGRATIONS_SCHEMA_QUERY.replace("%MIGRATION_TABLE_SCHEMA%", schema);
+            queries.push(assert_migrations_schema.as_str());
+            format!(r#""{schema}"."{migration_table_name}""#)
+        } else {
+            migration_table_name.to_string()
+        };
+
         // Needed cause some database vendors like Mssql have a non sql standard way of checking the migrations table,
         // thou on this case it's just to be consistent with the async trait `AsyncMigrate`
-        self.execute(&[Self::assert_migrations_table_query(migration_table_name).as_str()])
+        let assert_migrations_table = Self::assert_migrations_table_query(&migration_table_name);
+        queries.push(assert_migrations_table.as_str());
+        self.execute(&queries)
             .migration_err("error asserting migrations table", None)
     }
 
@@ -137,10 +154,14 @@ where
         abort_divergent: bool,
         abort_missing: bool,
         migration_table_name: &str,
+        migration_table_schema: Option<&str>,
     ) -> Result<Vec<Migration>, Error> {
-        self.assert_migrations_table(migration_table_name)?;
+        self.assert_migrations_table(migration_table_name, migration_table_schema)?;
 
-        let applied_migrations = self.get_applied_migrations(migration_table_name)?;
+        let migration_table_name = migration_table_schema
+            .map(|schema| format!(r#""{schema}"."{migration_table_name}""#))
+            .unwrap_or_else(|| migration_table_name.to_string());
+        let applied_migrations = self.get_applied_migrations(&migration_table_name)?;
 
         let migrations = verify_migrations(
             applied_migrations,
@@ -156,6 +177,7 @@
         Ok(migrations)
     }
 
+    #[allow(clippy::too_many_arguments)]
     fn migrate(
         &mut self,
         migrations: &[Migration],
         abort_divergent: bool,
         abort_missing: bool,
         grouped: bool,
         target: Target,
         migration_table_name: &str,
+        migration_table_schema: Option<&str>,
     ) -> Result<Report, Error> {
         let migrations = self.get_unapplied_migrations(
             migrations,
             abort_divergent,
             abort_missing,
             migration_table_name,
+            migration_table_schema,
         )?;
+        let migration_table_name = migration_table_schema
+            .map(|schema| format!(r#""{schema}"."{migration_table_name}""#))
+            .unwrap_or_else(|| migration_table_name.to_string());
 
         if grouped || matches!(target, Target::Fake | Target::FakeVersion(_)) {
-            migrate(self, migrations, target, migration_table_name, true)
+            migrate(self, migrations, target, &migration_table_name, true)
         } else {
-            migrate(self, migrations, target, migration_table_name, false)
+            migrate(self, migrations, target, &migration_table_name, false)
        }
     }
 }
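
Usage sketch: a minimal embedded-migrations consumer of the options introduced above, assuming the crate re-exports and feature flags of current refinery releases (a blocking Postgres driver enabled, `refinery::config` and `embed_migrations!` available); the `./sql_migrations` directory and the `inventory` schema name are illustrative. Only `set_db_schema`, `db_schema`, and `set_migration_table_schema` come from this patch.

    use refinery::config::{Config, ConfigDbType};

    // Illustrative path; any directory of V{version}__{name}.sql files works.
    refinery::embed_migrations!("./sql_migrations");

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        // New in this patch: `set_db_schema`. When the Config is parsed from a URL
        // instead, the schema is taken from a `currentSchema` query parameter.
        let mut config = Config::new(ConfigDbType::Postgres)
            .set_db_name("postgres")
            .set_db_user("postgres")
            .set_db_host("localhost")
            .set_db_port("5432")
            .set_db_schema("inventory");

        // New in this patch: `set_migration_table_schema`. The runner first issues
        // `CREATE SCHEMA IF NOT EXISTS "inventory"` and then keeps its history in
        // "inventory"."refinery_schema_history" instead of the default schema,
        // mirroring how the CLI wires `--table-schema` / `config.db_schema()`.
        migrations::runner()
            .set_migration_table_schema(config.db_schema())
            .run(&mut config)?;

        Ok(())
    }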