diff --git a/Cargo.lock b/Cargo.lock index 79906bb607..18e19ac08c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2757,9 +2757,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.38" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" dependencies = [ "proc-macro2", ] @@ -3468,6 +3468,7 @@ dependencies = [ "sha2", "smallvec", "sqlx", + "subst", "thiserror 2.0.11", "time", "tokio", @@ -3731,6 +3732,7 @@ dependencies = [ "quote", "sqlx-core", "sqlx-macros-core", + "subst", "syn 2.0.96", ] @@ -3752,6 +3754,7 @@ dependencies = [ "sqlx-mysql", "sqlx-postgres", "sqlx-sqlite", + "subst", "syn 2.0.96", "tokio", "url", @@ -3945,6 +3948,16 @@ dependencies = [ "syn 2.0.96", ] +[[package]] +name = "subst" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33e7942675ea19db01ef8cf15a1e6443007208e6c74568bd64162da26d40160d" +dependencies = [ + "memchr", + "unicode-width 0.1.14", +] + [[package]] name = "subtle" version = "2.6.1" diff --git a/Cargo.toml b/Cargo.toml index 6d08df23d3..8d814f0ca9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -185,6 +185,7 @@ uuid = "1.1.2" # Common utility crates dotenvy = { version = "0.15.7", default-features = false } +subst = "0.3.7" # Runtimes [workspace.dependencies.async-std] diff --git a/README.md b/README.md index f1e53cdced..9e882f300e 100644 --- a/README.md +++ b/README.md @@ -450,6 +450,26 @@ opt-level = 3 1 The `dotenv` crate itself appears abandoned as of [December 2021](https://github.com/dotenv-rs/dotenv/issues/74) so we now use the `dotenvy` crate instead. The file format is the same. +## Parameter Substitution for Migrations + +You can parameterize migrations using parameters, either from the environment or passed in from the cli or to the Migrator. + +For example: + +```sql +-- enable-substitution +CREATE USER ${USER_FROM_ENV} WITH PASSWORD ${PASSWORD_FROM_ENV} +-- disable-substituion +``` + +We use the [subst](https://crates.io/crates/subst) to support substitution. sqlx supports + +- Short format: `$NAME` +- Long format: `${NAME}` +- Default values: `${NAME:Bob}` +- Recursive Substitution in Default Values: `${NAME: Bob ${OTHER_NAME: and Alice}}` + + ## Safety This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in 100% Safe Rust. diff --git a/sqlx-cli/README.md b/sqlx-cli/README.md index b20461b8fd..eabf68c101 100644 --- a/sqlx-cli/README.md +++ b/sqlx-cli/README.md @@ -65,6 +65,18 @@ any scripts that are still pending. --- +Users can also provide parameters through environment variables or pass them in manually. + +```bash +sqlx migrate run --params-from-env +``` + +```bash +sqlx migrate run --params key:value,key1,value1 +``` + +--- + Users can provide the directory for the migration scripts to `sqlx migrate` subcommands with the `--source` flag. ```bash @@ -105,6 +117,16 @@ Creating migrations/20211001154420_.up.sql Creating migrations/20211001154420_.down.sql ``` +Users can also provide parameters through environment variables or pass them in manually, just as they did with the run command. + +```bash +sqlx migrate revert --params-from-env +``` + +```bash +sqlx migrate revert --params key:value,key1,value1 +``` + ### Enable building in "offline mode" with `query!()` There are 2 steps to building with "offline mode": diff --git a/sqlx-cli/src/database.rs b/sqlx-cli/src/database.rs index eaba46eed9..b0cde2defa 100644 --- a/sqlx-cli/src/database.rs +++ b/sqlx-cli/src/database.rs @@ -62,7 +62,17 @@ pub async fn setup( connect_opts: &ConnectOpts, ) -> anyhow::Result<()> { create(connect_opts).await?; - migrate::run(config, migration_source, connect_opts, false, false, None).await + migrate::run( + config, + migration_source, + connect_opts, + false, + false, + None, + false, + Vec::new(), + ) + .await } async fn ask_to_continue_drop(db_url: String) -> bool { diff --git a/sqlx-cli/src/lib.rs b/sqlx-cli/src/lib.rs index e3f21c9863..068d4aeece 100644 --- a/sqlx-cli/src/lib.rs +++ b/sqlx-cli/src/lib.rs @@ -63,6 +63,8 @@ async fn do_run(opt: Opt) -> anyhow::Result<()> { ignore_missing, mut connect_opts, target_version, + params_from_env, + params, } => { let config = config.load_config().await?; @@ -75,6 +77,8 @@ async fn do_run(opt: Opt) -> anyhow::Result<()> { dry_run, *ignore_missing, target_version, + params_from_env, + params, ) .await? } @@ -85,6 +89,8 @@ async fn do_run(opt: Opt) -> anyhow::Result<()> { ignore_missing, mut connect_opts, target_version, + params_from_env, + params, } => { let config = config.load_config().await?; @@ -97,6 +103,8 @@ async fn do_run(opt: Opt) -> anyhow::Result<()> { dry_run, *ignore_missing, target_version, + params_from_env, + params, ) .await? } diff --git a/sqlx-cli/src/migrate.rs b/sqlx-cli/src/migrate.rs index 926e264032..b846ada18b 100644 --- a/sqlx-cli/src/migrate.rs +++ b/sqlx-cli/src/migrate.rs @@ -220,6 +220,8 @@ pub async fn run( dry_run: bool, ignore_missing: bool, target_version: Option, + params_from_env: bool, + parameters: Vec<(String, String)>, ) -> anyhow::Result<()> { let migrator = migration_source.resolve(config).await?; @@ -264,6 +266,14 @@ pub async fn run( .map(|m| (m.version, m)) .collect(); + let env_params: HashMap<_, _> = if params_from_env { + std::env::vars().collect() + } else { + HashMap::with_capacity(0) + }; + + let params: HashMap<_, _> = parameters.into_iter().collect(); + for migration in migrator.iter() { if migration.migration_type.is_down_migration() { // Skipping down migrations @@ -282,6 +292,18 @@ pub async fn run( let elapsed = if dry_run || skip { Duration::new(0, 0) + } else if params_from_env { + conn.apply( + config.migrate.table_name(), + &migration.process_parameters(&env_params)?, + ) + .await? + } else if !params.is_empty() { + conn.apply( + config.migrate.table_name(), + &migration.process_parameters(¶ms)?, + ) + .await? } else { conn.apply(config.migrate.table_name(), migration).await? }; @@ -322,6 +344,8 @@ pub async fn revert( dry_run: bool, ignore_missing: bool, target_version: Option, + params_from_env: bool, + parameters: Vec<(String, String)>, ) -> anyhow::Result<()> { let migrator = migration_source.resolve(config).await?; @@ -368,6 +392,15 @@ pub async fn revert( .collect(); let mut is_applied = false; + + let env_params: HashMap<_, _> = if params_from_env { + std::env::vars().collect() + } else { + HashMap::with_capacity(0) + }; + + let params: HashMap<_, _> = parameters.into_iter().collect(); + for migration in migrator.iter().rev() { if !migration.migration_type.is_down_migration() { // Skipping non down migration @@ -381,6 +414,18 @@ pub async fn revert( let elapsed = if dry_run || skip { Duration::new(0, 0) + } else if params_from_env { + conn.revert( + config.migrate.table_name(), + &migration.process_parameters(&env_params)?, + ) + .await? + } else if !params.is_empty() { + conn.revert( + config.migrate.table_name(), + &migration.process_parameters(¶ms)?, + ) + .await? } else { conn.revert(config.migrate.table_name(), migration).await? }; diff --git a/sqlx-cli/src/opt.rs b/sqlx-cli/src/opt.rs index cb09bc2ff5..63fa7143ab 100644 --- a/sqlx-cli/src/opt.rs +++ b/sqlx-cli/src/opt.rs @@ -10,6 +10,7 @@ use clap::{ use clap_complete::Shell; use sqlx::migrate::{MigrateError, Migrator, ResolveWith}; use std::env; +use std::error::Error; use std::ops::{Deref, Not}; use std::path::PathBuf; @@ -245,6 +246,15 @@ pub enum MigrateCommand { /// pending migrations. If already at the target version, then no-op. #[clap(long)] target_version: Option, + + #[clap(long)] + /// Template parameters for substitution in migrations from environment variables + params_from_env: bool, + + #[clap(long, short, value_parser = parse_key_val::, num_args = 1, value_delimiter=',')] + /// Provide template parameters for substitution in migrations, e.g. --params + /// key:value,key2:value2 + params: Vec<(String, String)>, }, /// Revert the latest migration with a down file. @@ -270,6 +280,15 @@ pub enum MigrateCommand { /// at the target version, then no-op. #[clap(long)] target_version: Option, + + #[clap(long)] + /// Template parameters for substitution in migrations from environment variables + params_from_env: bool, + + #[clap(long, short, value_parser = parse_key_val::, num_args = 1, value_delimiter=',')] + /// Provide template parameters for substitution in migrations, e.g. --params + /// key:value,key2:value2 + params: Vec<(String, String)>, }, /// List all available migrations. @@ -575,3 +594,17 @@ fn next_timestamp() -> String { fn fmt_sequential(version: i64) -> String { format!("{version:04}") } + +/// Parse a single key-value pair +fn parse_key_val(s: &str) -> Result<(T, U), Box> +where + T: std::str::FromStr, + T::Err: Error + Send + Sync + 'static, + U: std::str::FromStr, + U::Err: Error + Send + Sync + 'static, +{ + let pos = s + .find('=') + .ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?; + Ok((s[..pos].parse()?, s[pos + 1..].parse()?)) +} diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 37cf9d3b91..52c4a2a567 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -91,6 +91,7 @@ hashlink = "0.10.0" indexmap = "2.0" event-listener = "5.2.0" hashbrown = "0.15.0" +subst = { workspace = true } [dev-dependencies] sqlx = { workspace = true, features = ["postgres", "sqlite", "mysql", "migrate", "macros", "time", "uuid"] } diff --git a/sqlx-core/src/migrate/error.rs b/sqlx-core/src/migrate/error.rs index a04243963a..5b1c74b722 100644 --- a/sqlx-core/src/migrate/error.rs +++ b/sqlx-core/src/migrate/error.rs @@ -40,6 +40,12 @@ pub enum MigrateError { )] Dirty(i64), + #[error("migration {0} was missing a parameter '{1}' at line {2}, column {3}")] + MissingParameter(String, String, usize, usize), + + #[error("Invalid parameter syntax {0}")] + InvalidParameterSyntax(String), + #[error("database driver does not support creation of schemas at migrate time: {0}")] CreateSchemasNotSupported(String), } diff --git a/sqlx-core/src/migrate/migration.rs b/sqlx-core/src/migrate/migration.rs index 79721d244d..e0960892b0 100644 --- a/sqlx-core/src/migrate/migration.rs +++ b/sqlx-core/src/migrate/migration.rs @@ -1,9 +1,13 @@ use sha2::{Digest, Sha384}; use std::borrow::Cow; +use std::collections::HashMap; -use crate::sql_str::SqlStr; +use crate::sql_str::{SqlSafeStr, SqlStr}; -use super::MigrationType; +use super::{MigrateError, MigrationType}; + +const ENABLE_SUBSTITUTION: &str = "-- enable-substitution"; +const DISABLE_SUBSTITUTION: &str = "-- disable-substitution"; #[derive(Debug, Clone)] pub struct Migration { @@ -52,6 +56,76 @@ impl Migration { no_tx, } } + + fn name(&self) -> String { + let description = self.description.replace(' ', "_"); + match self.migration_type { + MigrationType::Simple => { + format!("{}_{}", self.version, description) + } + MigrationType::ReversibleUp => { + format!("{}_{}.{}", self.version, description, "up") + } + MigrationType::ReversibleDown => { + format!("{}_{}.{}", self.version, description, "down") + } + } + } + + pub fn process_parameters( + &self, + params: &HashMap, + ) -> Result { + let Migration { + version, + description, + migration_type, + sql, + checksum, + no_tx, + } = self; + + let mut new_sql = String::with_capacity(sql.as_str().len()); + let mut substitution_enabled = false; + + for (i, line) in sql.as_str().lines().enumerate() { + if i != 0 { + new_sql.push('\n') + } + let trimmed_line = line.trim(); + if trimmed_line == ENABLE_SUBSTITUTION { + substitution_enabled = true; + new_sql.push_str(line); + continue; + } else if trimmed_line == DISABLE_SUBSTITUTION { + new_sql.push_str(line); + substitution_enabled = false; + continue; + } + + if substitution_enabled { + let substituted_line = subst::substitute(line, params).map_err(|e| match e { + subst::Error::NoSuchVariable(subst::error::NoSuchVariable { + position, + name, + }) => MigrateError::MissingParameter(self.name(), name, i + 1, position), + _ => MigrateError::InvalidParameterSyntax(e.to_string()), + })?; + new_sql.push_str(&substituted_line); + } else { + new_sql.push_str(line); + } + } + + Ok(Migration { + version: *version, + description: description.clone(), + migration_type: *migration_type, + sql: crate::sql_str::AssertSqlSafe(new_sql).into_sql_str(), + checksum: checksum.clone(), + no_tx: *no_tx, + }) + } } #[derive(Debug, Clone)] @@ -74,24 +148,154 @@ pub fn checksum_fragments<'a>(fragments: impl Iterator) -> Vec Result<(), MigrateError> { + const CREATE_USER: &str = r#" + -- enable-substitution + CREATE USER '${substitution_test_user}'; + -- disable-substitution + CREATE TABLE foo ( + id BIG SERIAL PRIMARY KEY + foo TEXT + ); + -- enable-substitution + DROP USER '${substitution_test_user}'; + -- disable-substitution + "#; + const EXPECTED_RESULT: &str = r#" + -- enable-substitution + CREATE USER 'my_user'; + -- disable-substitution + CREATE TABLE foo ( + id BIG SERIAL PRIMARY KEY + foo TEXT + ); + -- enable-substitution + DROP USER 'my_user'; + -- disable-substitution + "#; + + let migration = Migration::new( + 1, + Cow::Owned("test a simple parameter substitution".to_string()), + crate::migrate::MigrationType::Simple, + crate::sql_str::AssertSqlSafe(CREATE_USER.to_string()).into_sql_str(), + true, + ); + let result = migration.process_parameters(&HashMap::from([( + String::from("substitution_test_user"), + String::from("my_user"), + )]))?; + assert_eq!(result.sql, EXPECTED_RESULT); + Ok(()) + } + + #[test] + fn test_migration_process_parameters_no_substitution() -> Result<(), MigrateError> { + const CREATE_TABLE: &str = r#" + CREATE TABLE foo ( + id BIG SERIAL PRIMARY KEY + foo TEXT + ); + "#; + let migration = Migration::new( + 1, + std::borrow::Cow::Owned("test a simple parameter substitution".to_string()), + crate::migrate::MigrationType::Simple, + crate::sql_str::AssertSqlSafe(CREATE_TABLE.to_string()).into_sql_str(), + true, + ); + let result = migration.process_parameters(&HashMap::from([( + String::from("substitution_test_user"), + String::from("my_user"), + )]))?; + assert_eq!(result.sql, CREATE_TABLE); + Ok(()) + } + + #[test] + fn test_migration_process_parameters_missing_key() -> Result<(), MigrateError> { + const CREATE_TABLE: &str = r#" + -- enable-substitution + CREATE TABLE foo ( + id BIG SERIAL PRIMARY KEY + foo TEXT, + field ${TEST_MISSING_KEY} + ); + -- disable-substitution + + "#; + let migration = Migration::new( + 1, + Cow::Owned("test a simple parameter substitution".to_string()), + crate::migrate::MigrationType::Simple, + crate::sql_str::AssertSqlSafe(CREATE_TABLE.to_string()).into_sql_str(), + true, + ); + let Err(MigrateError::MissingParameter(..)) = + migration.process_parameters(&HashMap::with_capacity(0)) + else { + panic!("Missing env var not caught in process parameters missing env var") + }; + Ok(()) + } + + #[test] + fn test_migration_process_parameters_missing_key_with_default_value() -> Result<(), MigrateError> + { + const CREATE_TABLE: &str = r#" + -- enable-substitution + CREATE TABLE foo ( + id BIG SERIAL PRIMARY KEY + foo TEXT, + field ${TEST_MISSING_KEY:TEXT} + ); + -- disable-substitution + "#; + const EXPECTED_CREATE_TABLE: &str = r#" + -- enable-substitution + CREATE TABLE foo ( + id BIG SERIAL PRIMARY KEY + foo TEXT, + field TEXT + ); + -- disable-substitution + "#; + let migration = Migration::new( + 1, + Cow::Owned("test a simple parameter substitution".to_string()), + crate::migrate::MigrationType::Simple, + crate::sql_str::AssertSqlSafe(CREATE_TABLE.to_string()).into_sql_str(), + true, + ); + let result = migration.process_parameters(&HashMap::with_capacity(0))?; + assert_eq!(result.sql, EXPECTED_CREATE_TABLE); + Ok(()) + } + + #[test] + fn fragments_checksum_equals_full_checksum() { + // Copied from `examples/postgres/axum-social-with-tests/migrations/3_comment.sql` + let sql = "\ + \u{FEFF}create table comment (\r\n\ + \tcomment_id uuid primary key default gen_random_uuid(),\r\n\ + \tpost_id uuid not null references post(post_id),\r\n\ + \tuser_id uuid not null references \"user\"(user_id),\r\n\ + \tcontent text not null,\r\n\ + \tcreated_at timestamptz not null default now()\r\n\ + );\r\n\ + \r\n\ + create index on comment(post_id, created_at);\r\n\ + "; + + // Should yield a string for each character + let fragments_checksum = checksum_fragments(sql.split("")); + let full_checksum = checksum(sql); + + assert_eq!(fragments_checksum, full_checksum); + } } diff --git a/sqlx-core/src/migrate/migrator.rs b/sqlx-core/src/migrate/migrator.rs index 1ae4813106..c0e8f13f11 100644 --- a/sqlx-core/src/migrate/migrator.rs +++ b/sqlx-core/src/migrate/migrator.rs @@ -24,8 +24,11 @@ pub struct Migrator { #[doc(hidden)] pub no_tx: bool, #[doc(hidden)] + pub template_params: Option>, + #[doc(hidden)] + pub template_parameters_from_env: bool, + #[doc(hidden)] pub table_name: Cow<'static, str>, - #[doc(hidden)] pub create_schemas: Cow<'static, [Cow<'static, str>]>, } @@ -37,10 +40,35 @@ impl Migrator { ignore_missing: false, no_tx: false, locking: true, + template_params: None, + template_parameters_from_env: false, table_name: Cow::Borrowed("_sqlx_migrations"), create_schemas: Cow::Borrowed(&[]), }; + /// Set or update template parameters for migration placeholders. + /// + /// # Examples + /// + /// ```rust + /// # use sqlx_core::migrate::Migrator; + /// let mut migrator = Migrator::DEFAULT; + /// migrator.set_template_parameters(vec![("key", "value"), ("name", "test")]); + /// ``` + pub fn set_template_parameters(&mut self, params: I) -> &Self + where + I: IntoIterator, + K: Into, + V: Into, + { + let map: HashMap = params + .into_iter() + .map(|(k, v)| (k.into(), v.into())) + .collect(); + self.template_params = Some(map); + self + } + /// Creates a new instance with the given source. /// /// # Examples @@ -107,6 +135,15 @@ impl Migrator { self } + /// Specify whether template parameters for migrations should be read from the environment + pub fn set_template_parameters_from_env( + &mut self, + template_paramaters_from_env: bool, + ) -> &Self { + self.template_parameters_from_env = template_paramaters_from_env; + self + } + /// Specify whether or not to lock the database during migration. Defaults to `true`. /// /// ### Warning @@ -198,6 +235,12 @@ impl Migrator { .map(|m| (m.version, m)) .collect(); + let env_params = if self.template_parameters_from_env { + Some(std::env::vars().collect()) + } else { + None + }; + for migration in self.iter() { if target.is_some_and(|target| target < migration.version) { // Target version reached @@ -215,7 +258,18 @@ impl Migrator { } } None => { - conn.apply(&self.table_name, migration).await?; + if self.template_parameters_from_env { + conn.apply( + &self.table_name, + &migration.process_parameters(env_params.as_ref().unwrap())?, + ) + .await?; + } else if let Some(params) = self.template_params.as_ref() { + conn.apply(&self.table_name, &migration.process_parameters(params)?) + .await?; + } else { + conn.apply(&self.table_name, migration).await?; + } } } } @@ -275,6 +329,12 @@ impl Migrator { .map(|m| (m.version, m)) .collect(); + let env_params = if self.template_parameters_from_env { + Some(std::env::vars().collect()) + } else { + None + }; + for migration in self .iter() .rev() @@ -282,7 +342,18 @@ impl Migrator { .filter(|m| applied_migrations.contains_key(&m.version)) .filter(|m| m.version > target) { - conn.revert(&self.table_name, migration).await?; + if self.template_parameters_from_env { + conn.revert( + &self.table_name, + &migration.process_parameters(env_params.as_ref().unwrap())?, + ) + .await?; + } else if let Some(params) = self.template_params.as_ref() { + conn.revert(&self.table_name, &migration.process_parameters(params)?) + .await?; + } else { + conn.revert(&self.table_name, migration).await?; + } } // unlock the migrator to allow other migrators to run diff --git a/sqlx-macros-core/Cargo.toml b/sqlx-macros-core/Cargo.toml index 3c0b409113..350561787b 100644 --- a/sqlx-macros-core/Cargo.toml +++ b/sqlx-macros-core/Cargo.toml @@ -69,6 +69,7 @@ proc-macro2 = { version = "1.0.79", default-features = false } serde = { version = "1.0.132", features = ["derive"] } serde_json = { version = "1.0.73" } sha2 = { version = "0.10.0" } +subst = { workspace = true } syn = { version = "2.0.52", default-features = false, features = ["full", "derive", "parsing", "printing", "clone-impls"] } quote = { version = "1.0.26", default-features = false } url = { version = "2.2.2" } diff --git a/sqlx-macros-core/src/migrate.rs b/sqlx-macros-core/src/migrate.rs index b855703c22..ea7eb04a2c 100644 --- a/sqlx-macros-core/src/migrate.rs +++ b/sqlx-macros-core/src/migrate.rs @@ -1,6 +1,7 @@ #[cfg(any(sqlx_macros_unstable, procmacro2_semver_exempt))] extern crate proc_macro; +use std::collections::HashMap; use std::path::{Path, PathBuf}; use proc_macro2::{Span, TokenStream}; @@ -102,7 +103,39 @@ pub fn expand(path_arg: Option) -> crate::Result { expand_with_path(&config, &path) } +pub fn expand_migrator_from_lit_dir( + dir: LitStr, + parameters: Option>, +) -> crate::Result { + expand_migrator_from_dir(&dir.value(), dir.span(), parameters) +} + +pub(crate) fn expand_migrator_from_dir( + dir: &str, + err_span: proc_macro2::Span, + parameters: Option>, +) -> crate::Result { + let path = crate::common::resolve_path(dir, err_span)?; + expand_migrator(&path, parameters) +} + pub fn expand_with_path(config: &Config, path: &Path) -> crate::Result { + expand_migrator_with_config(config, path, None) +} + +pub(crate) fn expand_migrator( + path: &Path, + parameters: Option>, +) -> crate::Result { + let config = Config::try_from_crate_or_default()?; + expand_migrator_with_config(&config, path, parameters) +} + +pub(crate) fn expand_migrator_with_config( + config: &Config, + path: &Path, + parameters: Option>, +) -> crate::Result { let path = path.canonicalize().map_err(|e| { format!( "error canonicalizing migration directory {}: {e}", @@ -115,7 +148,14 @@ pub fn expand_with_path(config: &Config, path: &Path) -> crate::Result crate::Result { let path = crate::migrate::default_path(&config); - let resolved_path = crate::common::resolve_path(path, proc_macro2::Span::call_site())?; if resolved_path.is_dir() { diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml index 1c3cd96bff..4371ecb06c 100644 --- a/sqlx-macros/Cargo.toml +++ b/sqlx-macros/Cargo.toml @@ -57,6 +57,7 @@ sqlx-macros-core = { workspace = true } proc-macro2 = { version = "1.0.36", default-features = false } syn = { version = "2.0.52", default-features = false, features = ["parsing", "proc-macro"] } quote = { version = "1.0.26", default-features = false } +subst.workspace = true [lints] workspace = true diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index ccffc9bd2a..7f85a06652 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -66,10 +66,122 @@ pub fn derive_from_row(input: TokenStream) -> TokenStream { #[cfg(feature = "migrate")] #[proc_macro] pub fn migrate(input: TokenStream) -> TokenStream { - use syn::LitStr; + use quote::quote; + use std::collections::HashMap; + use syn::{parse_macro_input, Expr, ExprArray, ExprLit, ExprPath, ExprTuple, Lit, LitStr}; - let input = syn::parse_macro_input!(input as Option); - match migrate::expand(input) { + // Extract directory path, handling both direct literals and grouped literals + fn extract_dir(expr: Option) -> LitStr { + match expr { + Some(Expr::Lit(ExprLit { + lit: Lit::Str(literal), + .. + })) => return literal, + Some(Expr::Group(group)) => { + if let Expr::Lit(ExprLit { + lit: Lit::Str(literal), + .. + }) = *group.expr + { + return literal; + } + } + _ => {} + } + panic!("Expected a string literal for the directory path."); + } + + // Extract a `String` value from an expression (either a string literal or a variable) + fn extract_value(expr: Expr, location: &str) -> String { + match expr { + Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) => lit_str.value(), + Expr::Path(ExprPath { path, .. }) => path.segments.last().unwrap().ident.to_string(), + _ => panic!("Expected a string literal or a variable in {location}"), + } + } + + // Parse substitutions, expecting an array of tuples (String, Expr) + fn parse_substitutions(expr: Option) -> Option> { + let Expr::Group(group) = expr? else { + return None; + }; + let Expr::Array(ExprArray { elems, .. }) = *group.expr else { + panic!("Expected an array of tuples (String, Expr)."); + }; + + let mut map = HashMap::new(); + for elem in elems { + let Expr::Tuple(ExprTuple { + elems: tuple_elems, .. + }) = elem + else { + panic!("Expected a tuple (String, Expr). Got {:#?}", elem); + }; + + let mut tuple_elems = tuple_elems.into_iter(); + + let key = extract_value(tuple_elems.next().expect("Missing key in tuple."), "key"); + let value = extract_value( + tuple_elems.next().expect("Missing value in tuple."), + "value", + ); + map.insert(key, value); + } + Some(map) + } + + // Handle both the simple case (just path) and the tuple case (path + parameters) + let input_result: std::result::Result, syn::Error> = syn::parse(input.clone()); + if let Ok(simple_input) = input_result { + // Simple case: just a path or no arguments + return match migrate::expand(simple_input) { + Ok(ts) => ts.into(), + Err(e) => { + if let Some(parse_err) = e.downcast_ref::() { + parse_err.to_compile_error().into() + } else { + let msg = e.to_string(); + quote!(::std::compile_error!(#msg)).into() + } + } + }; + } + + // Complex case: parse tuple with parameters + let exp = parse_macro_input!(input as syn::Expr); + let (dir, parameters) = match exp { + Expr::Tuple(ExprTuple { elems, .. }) => { + let mut elems = elems.into_iter(); + (extract_dir(elems.next()), elems.next()) + } + Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) => { + (lit_str, None) + } + Expr::Group(group) => { + if let Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) = *group.expr + { + (lit_str, None) + } else { + panic!("Expected a tuple with directory path and optional parameters, or a string literal for the directory path."); + } + }, + _ => panic!( + "Expected a tuple with directory path and optional parameters, or a string literal for the directory path." + ), + }; + + // Parse substitutions and pass to migration expander + let substitutions = parse_substitutions(parameters); + match migrate::expand_migrator_from_lit_dir(dir, substitutions) { Ok(ts) => ts.into(), Err(e) => { if let Some(parse_err) = e.downcast_ref::() { diff --git a/src/macros/mod.rs b/src/macros/mod.rs index 0db6f0c2e7..366ef738f7 100644 --- a/src/macros/mod.rs +++ b/src/macros/mod.rs @@ -861,10 +861,15 @@ macro_rules! query_file_scalar_unchecked ( #[cfg(feature = "migrate")] #[macro_export] macro_rules! migrate { + ($directory:literal, parameters = $parameters:expr) => { + $crate::sqlx_macros::migrate!(($directory, $parameters)) + }; + (parameters = $parameters:expr) => { + $crate::sqlx_macros::migrate!(("./migrations", $parameters)) + }; ($dir:literal) => {{ $crate::sqlx_macros::migrate!($dir) }}; - () => {{ $crate::sqlx_macros::migrate!() }}; diff --git a/tests/ui-tests.rs b/tests/ui-tests.rs index 4a5ca240e1..fbdd14aa6f 100644 --- a/tests/ui-tests.rs +++ b/tests/ui-tests.rs @@ -44,3 +44,12 @@ fn ui_tests() { t.compile_fail("tests/ui/*.rs"); } + +#[test] +fn ui_migrate_tests() { + if cfg!(feature = "migrate") { + let t = trybuild::TestCases::new(); + t.compile_fail("tests/ui/migrate/invalid_key.rs"); + t.compile_fail("tests/ui/migrate/missing_parameter.rs"); + } +} diff --git a/tests/ui/migrate/invalid_key.rs b/tests/ui/migrate/invalid_key.rs new file mode 100644 index 0000000000..3edfaa8aa0 --- /dev/null +++ b/tests/ui/migrate/invalid_key.rs @@ -0,0 +1,4 @@ +fn main() { + //Fails due to invalid key + sqlx::migrate!("foo", parameters = [(123, "foo")]); +} diff --git a/tests/ui/migrate/invalid_key.stderr b/tests/ui/migrate/invalid_key.stderr new file mode 100644 index 0000000000..f92b52d11b --- /dev/null +++ b/tests/ui/migrate/invalid_key.stderr @@ -0,0 +1,8 @@ +error: proc macro panicked + --> tests/ui/migrate/invalid_key.rs:3:5 + | +3 | sqlx::migrate!("foo", parameters = [(123, "foo")]); + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = help: message: Expected a string literal or a variable in key + = note: this error originates in the macro `sqlx::migrate` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/tests/ui/migrate/migrations/20250423195520_create_users.down.sql b/tests/ui/migrate/migrations/20250423195520_create_users.down.sql new file mode 100644 index 0000000000..d2f607c5b8 --- /dev/null +++ b/tests/ui/migrate/migrations/20250423195520_create_users.down.sql @@ -0,0 +1 @@ +-- Add down migration script here diff --git a/tests/ui/migrate/migrations/20250423195520_create_users.up.sql b/tests/ui/migrate/migrations/20250423195520_create_users.up.sql new file mode 100644 index 0000000000..fc82f2cb9c --- /dev/null +++ b/tests/ui/migrate/migrations/20250423195520_create_users.up.sql @@ -0,0 +1,4 @@ +-- Add up migration script here +-- enable-substitution +CREATE USER ${my_user} WITH ENCRYPTED PASSWORD '${my_password}' INHERIT; +-- disable-substitution diff --git a/tests/ui/migrate/missing_parameter.rs b/tests/ui/migrate/missing_parameter.rs new file mode 100644 index 0000000000..a440b829ce --- /dev/null +++ b/tests/ui/migrate/missing_parameter.rs @@ -0,0 +1,5 @@ +fn main() { + //Fails due to missing migration parameter + let _shaggy = "shaggy"; + sqlx::migrate!("../../../../tests/ui/migrate/migrations", parameters = [("my_user", "scooby"), ("fooby", _shaggy)]); +} diff --git a/tests/ui/migrate/missing_parameter.stderr b/tests/ui/migrate/missing_parameter.stderr new file mode 100644 index 0000000000..0717bc64c4 --- /dev/null +++ b/tests/ui/migrate/missing_parameter.stderr @@ -0,0 +1,8 @@ +error: proc macro panicked + --> tests/ui/migrate/missing_parameter.rs:4:5 + | +4 | sqlx::migrate!("../../../../tests/ui/migrate/migrations", parameters = [("my_user", "scooby"), ("fooby", _shaggy)]); + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = help: message: Error processing parameters: migration 20250423195520_create_users.up was missing a parameter 'my_password' at line 3, column 50 + = note: this error originates in the macro `sqlx::migrate` (in Nightly builds, run with -Z macro-backtrace for more info)