Skip to content

feat: traits for transforming Types/TypeArgs/etc. #1991

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 19, 2025
Merged
216 changes: 213 additions & 3 deletions hugr-core/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,15 @@ impl SumType {
}
}

impl Transformable for SumType {
fn transform<T: TypeTransformer>(&mut self, tr: &T) -> Result<bool, T::Err> {
match self {
SumType::Unit { .. } => Ok(false),
SumType::General { rows } => rows.transform(tr),
}
}
}

impl<RV: MaybeRV> From<SumType> for TypeBase<RV> {
fn from(sum: SumType) -> Self {
match sum {
Expand Down Expand Up @@ -530,6 +539,36 @@ impl<RV: MaybeRV> TypeBase<RV> {
}
}

impl<RV: MaybeRV> Transformable for TypeBase<RV> {
fn transform<T: TypeTransformer>(&mut self, tr: &T) -> Result<bool, T::Err> {
match &mut self.0 {
TypeEnum::Alias(_) | TypeEnum::RowVar(_) | TypeEnum::Variable(..) => Ok(false),
TypeEnum::Extension(custom_type) => {
Ok(if let Some(nt) = tr.apply_custom(custom_type)? {
*self = nt.into_();
true
} else {
let args_changed = custom_type.args_mut().transform(tr)?;
if args_changed {
*self = Self::new_extension(
custom_type
.get_type_def(&custom_type.get_extension()?)?
Copy link
Collaborator

@doug-q doug-q Mar 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't looked at the other callsites of get_type_def, but I don't like it. I would prefer to inline it here, and just "pub(crate)"-icise get_extension. If you disagree I'm fine with this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd look at it the other way - how about removing get_extension by inlining into get_type_def and "pub(crate)"-icise only get_type_def?

Copy link
Contributor Author

@acl-cqc acl-cqc Mar 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are three callers of get_type_def and they all call get_extension first with only a slight change in how they handle errors from get_extension (i.e. one caller panics....on errors from either, so that's fine to combine both errors)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that makes sense, but I do think get_extension is a perfectly reasonable thing to want to do. Up to you.

Copy link
Contributor Author

@acl-cqc acl-cqc Mar 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OIC. The thing is that the caller has to hang onto the result Arc from get_extension for the lifetime of the &TypeDef. Ugh. I guess if Extension::get_type returned an Option<Arc<TypeDef>> rather than an Option<&TypeDef> then this would go away but otherwise we're a bit stuck.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I don't think the implementation of get_type_def is great either, but it seems mad to make the caller do that work (looking up the CustomType's type-id in the extension). Could change it to an Option rather than Result if that's preferable but the SignatureError is appropriate in at least a couple of cases, so I'm inclined to leave it.

#2001 covers what I think is the way forward but no need to do that in this PR.

.instantiate(custom_type.args())?,
);
}
args_changed
})
}
TypeEnum::Function(fty) => fty.transform(tr),
TypeEnum::Sum(sum_type) => {
let ch = sum_type.transform(tr)?;
self.1 = self.0.least_upper_bound();
Ok(ch)
}
}
}
}

impl Type {
fn substitute1(&self, s: &Substitution) -> Self {
let v = self.substitute(s);
Expand Down Expand Up @@ -666,6 +705,53 @@ impl<'a> Substitution<'a> {
}
}

/// A transformation that can be applied to a [Type] or [TypeArg].
/// More general in some ways than a Substitution: can fail with a
/// [Self::Err], may change [TypeBound::Copyable] to [TypeBound::Any],
/// and applies to arbitrary extension types rather than type variables.
pub trait TypeTransformer {
/// Error returned when a [CustomType] cannot be transformed, or a type
/// containing it (e.g. if changing a [TypeArg::Type] from copyable to
/// linear invalidates a parameterized type).
type Err: std::error::Error + From<SignatureError>;

/// Applies the transformation to an extension type.
///
/// Note that if the [CustomType] has type arguments, these will *not*
/// have been transformed first (this might not produce a valid type
/// due to changes in [TypeBound]).
///
/// Returns a type to use instead, or None to indicate no change
/// (in which case, the TypeArgs will be transformed instead.
/// To prevent transforming the arguments, return `t.clone().into()`.)
fn apply_custom(&self, t: &CustomType) -> Result<Option<Type>, Self::Err>;

// Note: in future releases more methods may be added here to transform other types.
// By defaulting such trait methods to Ok(None), backwards compatibility will be preserved.
}

/// Trait for things that can be transformed by applying a [TypeTransformer].
/// (A destructive / in-place mutation.)
pub trait Transformable {
/// Applies a [TypeTransformer] to this instance.
///
/// Returns true if any part may have changed, or false for definitely no change.
///
/// If an Err occurs, `self` may be left in an inconsistent state (e.g. partially
/// transformed).
fn transform<T: TypeTransformer>(&mut self, t: &T) -> Result<bool, T::Err>;
}

impl<E: Transformable> Transformable for [E] {
fn transform<T: TypeTransformer>(&mut self, tr: &T) -> Result<bool, T::Err> {
let mut any_change = false;
for item in self {
any_change |= item.transform(tr)?;
}
Ok(any_change)
}
}

pub(crate) fn check_typevar_decl(
decls: &[TypeParam],
idx: usize,
Expand Down Expand Up @@ -693,12 +779,15 @@ pub(crate) fn check_typevar_decl(

#[cfg(test)]
pub(crate) mod test {

use std::sync::Weak;

use super::*;
use crate::extension::prelude::usize_t;
use crate::type_row;
use crate::extension::prelude::{qb_t, usize_t};
use crate::extension::TypeDefBound;
use crate::std_extensions::collections::array::{array_type, array_type_parametric};
use crate::std_extensions::collections::list::list_type;
use crate::types::type_param::TypeArgError;
use crate::{hugr::IdentList, type_row, Extension};

#[test]
fn construct() {
Expand Down Expand Up @@ -756,6 +845,127 @@ pub(crate) mod test {
}
}

pub(super) struct FnTransformer<T>(pub(super) T);
impl<T: Fn(&CustomType) -> Option<Type>> TypeTransformer for FnTransformer<T> {
type Err = SignatureError;

fn apply_custom(&self, t: &CustomType) -> Result<Option<Type>, Self::Err> {
Ok((self.0)(t))
}
}
#[test]
fn transform() {
const LIN: SmolStr = SmolStr::new_inline("MyLinear");
let e = Extension::new_test_arc(IdentList::new("TestExt").unwrap(), |e, w| {
e.add_type(LIN, vec![], String::new(), TypeDefBound::any(), w)
.unwrap();
});
let lin = e.get_type(&LIN).unwrap().instantiate([]).unwrap();

let lin_to_usize = FnTransformer(|ct: &CustomType| (*ct == lin).then_some(usize_t()));
let mut t = Type::new_extension(lin.clone());
assert_eq!(t.transform(&lin_to_usize), Ok(true));
assert_eq!(t, usize_t());

for coln in [
list_type,
|t| array_type(10, t),
|t| {
array_type_parametric(
TypeArg::new_var_use(0, TypeParam::bounded_nat(3.try_into().unwrap())),
t,
)
.unwrap()
},
] {
let mut t = coln(lin.clone().into());
assert_eq!(t.transform(&lin_to_usize), Ok(true));
let expected = coln(usize_t());
assert_eq!(t, expected);
assert_eq!(t.transform(&lin_to_usize), Ok(false));
assert_eq!(t, expected);
}
}

#[test]
fn transform_copyable_to_linear() {
const CPY: SmolStr = SmolStr::new_inline("MyCopyable");
const COLN: SmolStr = SmolStr::new_inline("ColnOfCopyableElems");
let e = Extension::new_test_arc(IdentList::new("TestExt").unwrap(), |e, w| {
e.add_type(CPY, vec![], String::new(), TypeDefBound::copyable(), w)
.unwrap();
e.add_type(
COLN,
vec![TypeParam::new_list(TypeBound::Copyable)],
String::new(),
TypeDefBound::copyable(),
w,
)
.unwrap();
});

let cpy = e.get_type(&CPY).unwrap().instantiate([]).unwrap();
let mk_opt = |t: Type| Type::new_sum([type_row![], TypeRow::from(t)]);

let cpy_to_qb = FnTransformer(|ct: &CustomType| (ct == &cpy).then_some(qb_t()));

let mut t = mk_opt(cpy.clone().into());
assert_eq!(t.transform(&cpy_to_qb), Ok(true));
assert_eq!(t, mk_opt(qb_t()));

let coln = e.get_type(&COLN).unwrap();
let c_of_cpy = coln
.instantiate([TypeArg::Sequence {
elems: vec![Type::from(cpy.clone()).into()],
}])
.unwrap();

let mut t = Type::new_extension(c_of_cpy.clone());
assert_eq!(
t.transform(&cpy_to_qb),
Err(SignatureError::from(TypeArgError::TypeMismatch {
param: TypeBound::Copyable.into(),
arg: qb_t().into()
}))
);

let mut t = Type::new_extension(
coln.instantiate([TypeArg::Sequence {
elems: vec![mk_opt(Type::from(cpy.clone())).into()],
}])
.unwrap(),
);
assert_eq!(
t.transform(&cpy_to_qb),
Err(SignatureError::from(TypeArgError::TypeMismatch {
param: TypeBound::Copyable.into(),
arg: mk_opt(qb_t()).into()
}))
);

// Finally, check handling Coln<Cpy> overrides handling of Cpy
let cpy_to_qb2 = FnTransformer(|ct: &CustomType| {
assert_ne!(ct, &cpy);
(ct == &c_of_cpy).then_some(usize_t())
});
let mut t = Type::new_extension(
coln.instantiate([TypeArg::Sequence {
elems: vec![Type::from(c_of_cpy.clone()).into(); 2],
}])
.unwrap(),
);
assert_eq!(t.transform(&cpy_to_qb2), Ok(true));
assert_eq!(
t,
Type::new_extension(
coln.instantiate([TypeArg::Sequence {
elems: vec![usize_t().into(); 2]
}])
.unwrap()
)
);
}

mod proptest {

use crate::proptest::RecursionDepth;
Expand Down
7 changes: 5 additions & 2 deletions hugr-core/src/types/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,18 @@ impl CustomType {
def.check_custom(self)
}

fn get_type_def<'a>(&self, ext: &'a Arc<Extension>) -> Result<&'a TypeDef, SignatureError> {
pub(super) fn get_type_def<'a>(
&self,
ext: &'a Arc<Extension>,
) -> Result<&'a TypeDef, SignatureError> {
ext.get_type(&self.id)
.ok_or(SignatureError::ExtensionTypeNotFound {
exn: self.extension.clone(),
typ: self.id.clone(),
})
}

fn get_extension(&self) -> Result<Arc<Extension>, SignatureError> {
pub(super) fn get_extension(&self) -> Result<Arc<Extension>, SignatureError> {
self.extension_ref
.upgrade()
.ok_or(SignatureError::MissingTypeExtension {
Expand Down
42 changes: 40 additions & 2 deletions hugr-core/src/types/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ use std::fmt::{self, Display};

use super::type_param::TypeParam;
use super::type_row::TypeRowBase;
use super::{MaybeRV, NoRV, RowVariable, Substitution, Type, TypeRow};
use super::{
MaybeRV, NoRV, RowVariable, Substitution, Transformable, Type, TypeRow, TypeTransformer,
};

use crate::core::PortIndex;
use crate::extension::resolution::{
Expand Down Expand Up @@ -142,6 +144,13 @@ impl<RV: MaybeRV> FuncTypeBase<RV> {
}
}

impl<RV: MaybeRV> Transformable for FuncTypeBase<RV> {
fn transform<T: TypeTransformer>(&mut self, tr: &T) -> Result<bool, T::Err> {
// TODO handle extension sets?
Ok(self.input.transform(tr)? | self.output.transform(tr)?)
}
}

impl FuncValueType {
/// If this FuncValueType contains any row variables, return one.
pub fn find_rowvar(&self) -> Option<RowVariable> {
Expand Down Expand Up @@ -330,7 +339,9 @@ impl<RV1: MaybeRV, RV2: MaybeRV> PartialEq<FuncTypeBase<RV1>> for Cow<'_, FuncTy

#[cfg(test)]
mod test {
use crate::{extension::prelude::usize_t, type_row};
use crate::extension::prelude::{bool_t, qb_t, usize_t};
use crate::type_row;
use crate::types::{test::FnTransformer, CustomType, TypeEnum};

use super::*;
#[test]
Expand Down Expand Up @@ -358,4 +369,31 @@ mod test {
(&type_row![Type::UNIT], &vec![usize_t()].into())
);
}

#[test]
fn test_transform() {
let TypeEnum::Extension(usz_t) = usize_t().as_type_enum().clone() else {
panic!()
};
let tr = FnTransformer(|ct: &CustomType| (ct == &usz_t).then_some(bool_t()));
let row_with = || TypeRow::from(vec![usize_t(), qb_t(), bool_t()]);
let row_after = || TypeRow::from(vec![bool_t(), qb_t(), bool_t()]);
let mut sig = Signature::new(row_with(), row_after());
let exp = Signature::new(row_after(), row_after());
assert_eq!(sig.transform(&tr), Ok(true));
assert_eq!(sig, exp);
assert_eq!(sig.transform(&tr), Ok(false));
assert_eq!(sig, exp);
let exp = Type::new_function(exp);
for fty in [
FuncValueType::new(row_after(), row_with()),
FuncValueType::new(row_with(), row_with()),
] {
let mut t = Type::new_function(fty);
assert_eq!(t.transform(&tr), Ok(true));
assert_eq!(t, exp);
assert_eq!(t.transform(&tr), Ok(false));
assert_eq!(t, exp);
}
}
}
18 changes: 17 additions & 1 deletion hugr-core/src/types/type_param.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ use std::num::NonZeroU64;
use thiserror::Error;

use super::row_var::MaybeRV;
use super::{check_typevar_decl, NoRV, RowVariable, Substitution, Type, TypeBase, TypeBound};
use super::{
check_typevar_decl, NoRV, RowVariable, Substitution, Transformable, Type, TypeBase, TypeBound,
TypeTransformer,
};
use crate::extension::ExtensionSet;
use crate::extension::SignatureError;

Expand Down Expand Up @@ -369,6 +372,19 @@ impl TypeArg {
}
}

impl Transformable for TypeArg {
fn transform<T: TypeTransformer>(&mut self, tr: &T) -> Result<bool, T::Err> {
match self {
TypeArg::Type { ty } => ty.transform(tr),
TypeArg::Sequence { elems } => elems.transform(tr),
TypeArg::BoundedNat { .. }
| TypeArg::String { .. }
| TypeArg::Extensions { .. }
| TypeArg::Variable { .. } => Ok(false),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case should be tested

}
}
}

impl TypeArgVariable {
/// Return the index.
pub fn index(&self) -> usize {
Expand Down
11 changes: 10 additions & 1 deletion hugr-core/src/types/type_row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ use std::{
ops::{Deref, DerefMut},
};

use super::{type_param::TypeParam, MaybeRV, NoRV, RowVariable, Substitution, Type, TypeBase};
use super::{
type_param::TypeParam, MaybeRV, NoRV, RowVariable, Substitution, Transformable, Type, TypeBase,
TypeTransformer,
};
use crate::{extension::SignatureError, utils::display_list};
use delegate::delegate;
use itertools::Itertools;
Expand Down Expand Up @@ -96,6 +99,12 @@ impl<RV: MaybeRV> TypeRowBase<RV> {
}
}

impl<RV: MaybeRV> Transformable for TypeRowBase<RV> {
fn transform<T: TypeTransformer>(&mut self, tr: &T) -> Result<bool, T::Err> {
self.to_mut().transform(tr)
}
}

impl TypeRow {
delegate! {
to self.types {
Expand Down
Loading