diff --git a/hugr-llvm/src/extension/int.rs b/hugr-llvm/src/extension/int.rs index d9bd051bf..3b34d3630 100644 --- a/hugr-llvm/src/extension/int.rs +++ b/hugr-llvm/src/extension/int.rs @@ -281,6 +281,22 @@ fn emit_int_op<'c, H: HugrView>( .as_basic_value_enum()]) }), IntOpDef::ipow => emit_ipow(context, args), + // Type args are width of input, width of output + IntOpDef::iwiden_u => emit_custom_unary_op(context, args, |ctx, arg, outs| { + let [out] = outs.try_into()?; + Ok(vec![ctx + .builder() + .build_int_cast_sign_flag(arg.into_int_value(), out.into_int_type(), false, "")? + .as_basic_value_enum()]) + }), + IntOpDef::iwiden_s => emit_custom_unary_op(context, args, |ctx, arg, outs| { + let [out] = outs.try_into()?; + + Ok(vec![ctx + .builder() + .build_int_cast_sign_flag(arg.into_int_value(), out.into_int_type(), true, "")? + .as_basic_value_enum()]) + }), _ => Err(anyhow!("IntOpEmitter: unimplemented op: {}", op.name())), } } @@ -346,6 +362,7 @@ mod test { use hugr_core::{ builder::{Dataflow, DataflowSubContainer}, extension::prelude::bool_t, + ops::ExtensionOp, std_extensions::arithmetic::{ int_ops, int_types::{ConstInt, INT_TYPES}, @@ -362,18 +379,25 @@ mod test { test::{exec_ctx, llvm_ctx, TestContext}, }; - fn test_binary_int_op(name: impl AsRef, log_width: u8) -> Hugr { + // Instantiate an extension op which takes one width argument + fn make_int_op(name: impl AsRef, log_width: u8) -> ExtensionOp { + int_ops::EXTENSION + .instantiate_extension_op(name.as_ref(), [(log_width as u64).into()]) + .unwrap() + } + + fn test_binary_int_op(ext_op: ExtensionOp, log_width: u8) -> Hugr { let ty = &INT_TYPES[log_width as usize]; - test_int_op_with_results::<2>(name, log_width, None, ty.clone()) + test_int_op_with_results::<2>(ext_op, log_width, None, ty.clone()) } - fn test_binary_icmp_op(name: impl AsRef, log_width: u8) -> Hugr { - test_int_op_with_results::<2>(name, log_width, None, bool_t()) + fn test_binary_icmp_op(ext_op: ExtensionOp, log_width: u8) -> Hugr { + test_int_op_with_results::<2>(ext_op, log_width, None, bool_t()) } fn test_int_op_with_results( // N is the number of inputs to the hugr - name: impl AsRef, + ext_op: ExtensionOp, log_width: u8, inputs: Option<[ConstInt; N]>, // If inputs are provided, they'll be wired into the op, otherwise the inputs to the hugr will be wired into the op output_type: Type, @@ -400,9 +424,6 @@ mod test { input_wires } }; - let ext_op = int_ops::EXTENSION - .instantiate_extension_op(name.as_ref(), [(log_width as u64).into()]) - .unwrap(); let outputs = hugr_builder .add_dataflow_op(ext_op, input_wires) .unwrap() @@ -415,7 +436,8 @@ mod test { fn test_neg_emission(mut llvm_ctx: TestContext) { llvm_ctx.add_extensions(add_int_extensions); let ty = INT_TYPES[2].clone(); - let hugr = test_int_op_with_results::<1>("ineg", 2, None, ty.clone()); + let ext_op = make_int_op("ineg", 2); + let hugr = test_int_op_with_results::<1>(ext_op, 2, None, ty.clone()); check_emission!("ineg", hugr, llvm_ctx); } @@ -425,16 +447,38 @@ mod test { #[case::ipow("ipow", 3)] fn test_binop_emission(mut llvm_ctx: TestContext, #[case] op: String, #[case] width: u8) { llvm_ctx.add_extensions(add_int_extensions); - let hugr = test_binary_int_op(op.clone(), width); + let ext_op = make_int_op(op.clone(), width); + let hugr = test_binary_int_op(ext_op, width); check_emission!(op.clone(), hugr, llvm_ctx); } + #[rstest] + #[case::signed_2_3("iwiden_s", 2, 3)] + #[case::signed_1_6("iwiden_s", 1, 6)] + #[case::unsigned_2_3("iwiden_u", 2, 3)] + #[case::unsigned_1_6("iwiden_u", 1, 6)] + fn test_widen_emission( + mut llvm_ctx: TestContext, + #[case] op: String, + #[case] from: u8, + #[case] to: u8, + ) { + llvm_ctx.add_extensions(add_int_extensions); + let out_ty = INT_TYPES[to as usize].clone(); + let ext_op = int_ops::EXTENSION + .instantiate_extension_op(&op, [(from as u64).into(), (to as u64).into()]) + .unwrap(); + let hugr = test_int_op_with_results::<1>(ext_op, from, None, out_ty.into()); + + check_emission!(op.clone(), hugr, llvm_ctx); + } #[rstest] #[case::ieq("ieq", 1)] #[case::ilt_s("ilt_s", 0)] fn test_cmp_emission(mut llvm_ctx: TestContext, #[case] op: String, #[case] width: u8) { llvm_ctx.add_extensions(add_int_extensions); - let hugr = test_binary_icmp_op(op.clone(), width); + let ext_op = make_int_op(op.clone(), width); + let hugr = test_binary_icmp_op(ext_op, width); check_emission!(op.clone(), hugr, llvm_ctx); } @@ -473,7 +517,9 @@ mod test { ConstInt::new_u(6, lhs).unwrap(), ConstInt::new_u(6, rhs).unwrap(), ]; - let hugr = test_int_op_with_results::<2>(op, 6, Some(inputs), ty.clone()); + let ext_op = make_int_op(&op, 6); + + let hugr = test_int_op_with_results::<2>(ext_op, 6, Some(inputs), ty.clone()); assert_eq!(exec_ctx.exec_hugr_u64(hugr, "main"), result); } @@ -506,7 +552,9 @@ mod test { ConstInt::new_s(6, lhs).unwrap(), ConstInt::new_s(6, rhs).unwrap(), ]; - let hugr = test_int_op_with_results::<2>(op, 6, Some(inputs), ty.clone()); + let ext_op = make_int_op(&op, 6); + + let hugr = test_int_op_with_results::<2>(ext_op, 6, Some(inputs), ty.clone()); assert_eq!(exec_ctx.exec_hugr_i64(hugr, "main"), result); } @@ -522,7 +570,9 @@ mod test { exec_ctx.add_extensions(add_int_extensions); let input = ConstInt::new_s(6, arg).unwrap(); let ty = INT_TYPES[6].clone(); - let hugr = test_int_op_with_results::<1>(op, 6, Some([input]), ty.clone()); + let ext_op = make_int_op(&op, 6); + + let hugr = test_int_op_with_results::<1>(ext_op, 6, Some([input]), ty.clone()); assert_eq!(exec_ctx.exec_hugr_i64(hugr, "main"), result); } @@ -539,7 +589,9 @@ mod test { exec_ctx.add_extensions(add_int_extensions); let input = ConstInt::new_u(6, arg).unwrap(); let ty = INT_TYPES[6].clone(); - let hugr = test_int_op_with_results::<1>(op, 6, Some([input]), ty.clone()); + let ext_op = make_int_op(&op, 6); + + let hugr = test_int_op_with_results::<1>(ext_op, 6, Some([input]), ty.clone()); assert_eq!(exec_ctx.exec_hugr_u64(hugr, "main"), result); } } diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__int__test__inarrow_u@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__int__test__inarrow_u@pre-mem2reg@llvm14.snap new file mode 100644 index 000000000..c7a1210e0 --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__int__test__inarrow_u@pre-mem2reg@llvm14.snap @@ -0,0 +1,28 @@ +--- +source: hugr-llvm/src/extension/int.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +@0 = private unnamed_addr constant [25 x i8] c"Can't narrow into bounds\00", align 1 + +define { i1, { i32, i8* }, i64 } @_hl.main.1(i64 %0) { +alloca_block: + %"0" = alloca { i1, { i32, i8* }, i64 }, align 8 + %"2_0" = alloca i64, align 8 + %"4_0" = alloca { i1, { i32, i8* }, i64 }, align 8 + br label %entry_block + +entry_block: ; preds = %alloca_block + store i64 %0, i64* %"2_0", align 4 + %"2_01" = load i64, i64* %"2_0", align 4 + %bounds_check = icmp ugt i64 %"2_01", -1 + %1 = insertvalue { i1, { i32, i8* }, i64 } { i1 true, { i32, i8* } poison, i64 poison }, i64 %"2_01", 2 + %2 = select i1 %bounds_check, { i1, { i32, i8* }, i64 } { i1 false, { i32, i8* } { i32 2, i8* getelementptr inbounds ([25 x i8], [25 x i8]* @0, i32 0, i32 0) }, i64 poison }, { i1, { i32, i8* }, i64 } %1 + store { i1, { i32, i8* }, i64 } %2, { i1, { i32, i8* }, i64 }* %"4_0", align 8 + %"4_02" = load { i1, { i32, i8* }, i64 }, { i1, { i32, i8* }, i64 }* %"4_0", align 8 + store { i1, { i32, i8* }, i64 } %"4_02", { i1, { i32, i8* }, i64 }* %"0", align 8 + %"03" = load { i1, { i32, i8* }, i64 }, { i1, { i32, i8* }, i64 }* %"0", align 8 + ret { i1, { i32, i8* }, i64 } %"03" +} diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__int__test__iwiden_s@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__int__test__iwiden_s@llvm14.snap new file mode 100644 index 000000000..413829754 --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__int__test__iwiden_s@llvm14.snap @@ -0,0 +1,15 @@ +--- +source: hugr-llvm/src/extension/int.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define i64 @_hl.main.1(i8 %0) { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + %1 = sext i8 %0 to i64 + ret i64 %1 +} diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__int__test__iwiden_s@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__int__test__iwiden_s@pre-mem2reg@llvm14.snap new file mode 100644 index 000000000..fcb34323d --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__int__test__iwiden_s@pre-mem2reg@llvm14.snap @@ -0,0 +1,24 @@ +--- +source: hugr-llvm/src/extension/int.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define i64 @_hl.main.1(i8 %0) { +alloca_block: + %"0" = alloca i64, align 8 + %"2_0" = alloca i8, align 1 + %"4_0" = alloca i64, align 8 + br label %entry_block + +entry_block: ; preds = %alloca_block + store i8 %0, i8* %"2_0", align 1 + %"2_01" = load i8, i8* %"2_0", align 1 + %1 = sext i8 %"2_01" to i64 + store i64 %1, i64* %"4_0", align 4 + %"4_02" = load i64, i64* %"4_0", align 4 + store i64 %"4_02", i64* %"0", align 4 + %"03" = load i64, i64* %"0", align 4 + ret i64 %"03" +} diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__int__test__iwiden_u@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__int__test__iwiden_u@llvm14.snap new file mode 100644 index 000000000..76b9aa90d --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__int__test__iwiden_u@llvm14.snap @@ -0,0 +1,14 @@ +--- +source: hugr-llvm/src/extension/int.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define i8 @_hl.main.1(i8 %0) { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + ret i8 %0 +} diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__int__test__iwiden_u@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__int__test__iwiden_u@pre-mem2reg@llvm14.snap new file mode 100644 index 000000000..0b60c0fe3 --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__int__test__iwiden_u@pre-mem2reg@llvm14.snap @@ -0,0 +1,23 @@ +--- +source: hugr-llvm/src/extension/int.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define i8 @_hl.main.1(i8 %0) { +alloca_block: + %"0" = alloca i8, align 1 + %"2_0" = alloca i8, align 1 + %"4_0" = alloca i8, align 1 + br label %entry_block + +entry_block: ; preds = %alloca_block + store i8 %0, i8* %"2_0", align 1 + %"2_01" = load i8, i8* %"2_0", align 1 + store i8 %"2_01", i8* %"4_0", align 1 + %"4_02" = load i8, i8* %"4_0", align 1 + store i8 %"4_02", i8* %"0", align 1 + %"03" = load i8, i8* %"0", align 1 + ret i8 %"03" +}