diff --git a/Cargo.lock b/Cargo.lock index fd2e6d0d7..d1cc452a8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -314,6 +314,7 @@ dependencies = [ "cw-storage-plus", "schemars", "serde", + "thiserror", ] [[package]] diff --git a/packages/cw2/Cargo.toml b/packages/cw2/Cargo.toml index 4e7854d41..e61642142 100644 --- a/packages/cw2/Cargo.toml +++ b/packages/cw2/Cargo.toml @@ -14,3 +14,4 @@ cosmwasm-std = { version = "1.0.1", default-features = false } cw-storage-plus = "1.0.1" schemars = "0.8.1" serde = { version = "1.0.1", default-features = false, features = ["derive"] } +thiserror = "1.0.23" diff --git a/packages/cw2/src/lib.rs b/packages/cw2/src/lib.rs index aae498609..201b77ae8 100644 --- a/packages/cw2/src/lib.rs +++ b/packages/cw2/src/lib.rs @@ -18,8 +18,11 @@ For more information on this specification, please check out the */ use cosmwasm_schema::cw_serde; -use cosmwasm_std::{CustomQuery, QuerierWrapper, QueryRequest, StdResult, Storage, WasmQuery}; +use cosmwasm_std::{ + CustomQuery, QuerierWrapper, QueryRequest, StdError, StdResult, Storage, WasmQuery, +}; use cw_storage_plus::Item; +use thiserror::Error; pub const CONTRACT: Item = Item::new("contract_info"); @@ -35,6 +38,51 @@ pub struct ContractVersion { pub version: String, } +#[derive(Error, Debug, PartialEq)] +pub enum VersionError { + #[error(transparent)] + Std(#[from] StdError), + + #[error("Contract version info not found")] + NotFound, + + #[error("Wrong contract: expecting `{expected}`, found `{found}`")] + WrongContract { expected: String, found: String }, + + #[error("Wrong contract version: expecting `{expected}`, found `{found}`")] + WrongVersion { expected: String, found: String }, +} + +/// Assert that the stored contract version info matches the given value. +/// This is useful during migrations, for making sure that the correct contract +/// is being migrated, and it's being migrated from the correct version. +pub fn assert_contract_version( + storage: &dyn Storage, + expected_contract: &str, + expected_version: &str, +) -> Result<(), VersionError> { + let ContractVersion { contract, version } = match CONTRACT.may_load(storage)? { + Some(contract) => contract, + None => return Err(VersionError::NotFound), + }; + + if contract != expected_contract { + return Err(VersionError::WrongContract { + expected: expected_contract.into(), + found: contract, + }); + } + + if version != expected_version { + return Err(VersionError::WrongVersion { + expected: expected_version.into(), + found: version, + }); + } + + Ok(()) +} + /// get_contract_version can be use in migrate to read the previous version of this contract pub fn get_contract_version(store: &dyn Storage) -> StdResult { CONTRACT.load(store) @@ -98,4 +146,44 @@ mod tests { }; assert_eq!(expected, loaded); } + + #[test] + fn assert_work() { + let mut store = MockStorage::new(); + + const EXPECTED_CONTRACT: &str = "crate:mars-red-bank"; + const EXPECTED_VERSION: &str = "1.0.0"; + + // error if contract version is not set + let err = assert_contract_version(&store, EXPECTED_CONTRACT, EXPECTED_VERSION).unwrap_err(); + assert_eq!(err, VersionError::NotFound); + + // wrong contract name + let wrong_contract = "crate:cw20-base"; + set_contract_version(&mut store, wrong_contract, EXPECTED_VERSION).unwrap(); + let err = assert_contract_version(&store, EXPECTED_CONTRACT, EXPECTED_VERSION).unwrap_err(); + assert_eq!( + err, + VersionError::WrongContract { + expected: EXPECTED_CONTRACT.into(), + found: wrong_contract.into() + }, + ); + + // wrong contract version + let wrong_version = "8.8.8"; + set_contract_version(&mut store, EXPECTED_CONTRACT, wrong_version).unwrap(); + let err = assert_contract_version(&store, EXPECTED_CONTRACT, EXPECTED_VERSION).unwrap_err(); + assert_eq!( + err, + VersionError::WrongVersion { + expected: EXPECTED_VERSION.into(), + found: wrong_version.into() + }, + ); + + // correct name and version + set_contract_version(&mut store, EXPECTED_CONTRACT, EXPECTED_VERSION).unwrap(); + assert!(assert_contract_version(&store, EXPECTED_CONTRACT, EXPECTED_VERSION).is_ok()); + } }