Skip to content

feat: Emit widen ops from the int ops extension #1946

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 67 additions & 15 deletions hugr-llvm/src/extension/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,22 @@ fn emit_int_op<'c, H: HugrView<Node = Node>>(
.as_basic_value_enum()])
}),
IntOpDef::ipow => emit_ipow(context, args),
// Type args are width of input, width of output
IntOpDef::iwiden_u => emit_custom_unary_op(context, args, |ctx, arg, outs| {
let [out] = outs.try_into()?;
Ok(vec![ctx
.builder()
.build_int_cast_sign_flag(arg.into_int_value(), out.into_int_type(), false, "")?
.as_basic_value_enum()])
}),
IntOpDef::iwiden_s => emit_custom_unary_op(context, args, |ctx, arg, outs| {
let [out] = outs.try_into()?;

Ok(vec![ctx
.builder()
.build_int_cast_sign_flag(arg.into_int_value(), out.into_int_type(), true, "")?
.as_basic_value_enum()])
}),
_ => Err(anyhow!("IntOpEmitter: unimplemented op: {}", op.name())),
}
}
Expand Down Expand Up @@ -346,6 +362,7 @@ mod test {
use hugr_core::{
builder::{Dataflow, DataflowSubContainer},
extension::prelude::bool_t,
ops::ExtensionOp,
std_extensions::arithmetic::{
int_ops,
int_types::{ConstInt, INT_TYPES},
Expand All @@ -362,18 +379,25 @@ mod test {
test::{exec_ctx, llvm_ctx, TestContext},
};

fn test_binary_int_op(name: impl AsRef<str>, log_width: u8) -> Hugr {
// Instantiate an extension op which takes one width argument
fn make_int_op(name: impl AsRef<str>, log_width: u8) -> ExtensionOp {
int_ops::EXTENSION
.instantiate_extension_op(name.as_ref(), [(log_width as u64).into()])
.unwrap()
}

fn test_binary_int_op(ext_op: ExtensionOp, log_width: u8) -> Hugr {
let ty = &INT_TYPES[log_width as usize];
test_int_op_with_results::<2>(name, log_width, None, ty.clone())
test_int_op_with_results::<2>(ext_op, log_width, None, ty.clone())
}

fn test_binary_icmp_op(name: impl AsRef<str>, log_width: u8) -> Hugr {
test_int_op_with_results::<2>(name, log_width, None, bool_t())
fn test_binary_icmp_op(ext_op: ExtensionOp, log_width: u8) -> Hugr {
test_int_op_with_results::<2>(ext_op, log_width, None, bool_t())
}

fn test_int_op_with_results<const N: usize>(
// N is the number of inputs to the hugr
name: impl AsRef<str>,
ext_op: ExtensionOp,
log_width: u8,
inputs: Option<[ConstInt; N]>, // If inputs are provided, they'll be wired into the op, otherwise the inputs to the hugr will be wired into the op
output_type: Type,
Expand All @@ -400,9 +424,6 @@ mod test {
input_wires
}
};
let ext_op = int_ops::EXTENSION
.instantiate_extension_op(name.as_ref(), [(log_width as u64).into()])
.unwrap();
let outputs = hugr_builder
.add_dataflow_op(ext_op, input_wires)
.unwrap()
Expand All @@ -415,7 +436,8 @@ mod test {
fn test_neg_emission(mut llvm_ctx: TestContext) {
llvm_ctx.add_extensions(add_int_extensions);
let ty = INT_TYPES[2].clone();
let hugr = test_int_op_with_results::<1>("ineg", 2, None, ty.clone());
let ext_op = make_int_op("ineg", 2);
let hugr = test_int_op_with_results::<1>(ext_op, 2, None, ty.clone());
check_emission!("ineg", hugr, llvm_ctx);
}

Expand All @@ -425,16 +447,38 @@ mod test {
#[case::ipow("ipow", 3)]
fn test_binop_emission(mut llvm_ctx: TestContext, #[case] op: String, #[case] width: u8) {
llvm_ctx.add_extensions(add_int_extensions);
let hugr = test_binary_int_op(op.clone(), width);
let ext_op = make_int_op(op.clone(), width);
let hugr = test_binary_int_op(ext_op, width);
check_emission!(op.clone(), hugr, llvm_ctx);
}

#[rstest]
#[case::signed_2_3("iwiden_s", 2, 3)]
#[case::signed_1_6("iwiden_s", 1, 6)]
#[case::unsigned_2_3("iwiden_u", 2, 3)]
#[case::unsigned_1_6("iwiden_u", 1, 6)]
fn test_widen_emission(
mut llvm_ctx: TestContext,
#[case] op: String,
#[case] from: u8,
#[case] to: u8,
) {
llvm_ctx.add_extensions(add_int_extensions);
let out_ty = INT_TYPES[to as usize].clone();
let ext_op = int_ops::EXTENSION
.instantiate_extension_op(&op, [(from as u64).into(), (to as u64).into()])
.unwrap();
let hugr = test_int_op_with_results::<1>(ext_op, from, None, out_ty.into());

check_emission!(op.clone(), hugr, llvm_ctx);
}
#[rstest]
#[case::ieq("ieq", 1)]
#[case::ilt_s("ilt_s", 0)]
fn test_cmp_emission(mut llvm_ctx: TestContext, #[case] op: String, #[case] width: u8) {
llvm_ctx.add_extensions(add_int_extensions);
let hugr = test_binary_icmp_op(op.clone(), width);
let ext_op = make_int_op(op.clone(), width);
let hugr = test_binary_icmp_op(ext_op, width);
check_emission!(op.clone(), hugr, llvm_ctx);
}

Expand Down Expand Up @@ -473,7 +517,9 @@ mod test {
ConstInt::new_u(6, lhs).unwrap(),
ConstInt::new_u(6, rhs).unwrap(),
];
let hugr = test_int_op_with_results::<2>(op, 6, Some(inputs), ty.clone());
let ext_op = make_int_op(&op, 6);

let hugr = test_int_op_with_results::<2>(ext_op, 6, Some(inputs), ty.clone());
assert_eq!(exec_ctx.exec_hugr_u64(hugr, "main"), result);
}

Expand Down Expand Up @@ -506,7 +552,9 @@ mod test {
ConstInt::new_s(6, lhs).unwrap(),
ConstInt::new_s(6, rhs).unwrap(),
];
let hugr = test_int_op_with_results::<2>(op, 6, Some(inputs), ty.clone());
let ext_op = make_int_op(&op, 6);

let hugr = test_int_op_with_results::<2>(ext_op, 6, Some(inputs), ty.clone());
assert_eq!(exec_ctx.exec_hugr_i64(hugr, "main"), result);
}

Expand All @@ -522,7 +570,9 @@ mod test {
exec_ctx.add_extensions(add_int_extensions);
let input = ConstInt::new_s(6, arg).unwrap();
let ty = INT_TYPES[6].clone();
let hugr = test_int_op_with_results::<1>(op, 6, Some([input]), ty.clone());
let ext_op = make_int_op(&op, 6);

let hugr = test_int_op_with_results::<1>(ext_op, 6, Some([input]), ty.clone());
assert_eq!(exec_ctx.exec_hugr_i64(hugr, "main"), result);
}

Expand All @@ -539,7 +589,9 @@ mod test {
exec_ctx.add_extensions(add_int_extensions);
let input = ConstInt::new_u(6, arg).unwrap();
let ty = INT_TYPES[6].clone();
let hugr = test_int_op_with_results::<1>(op, 6, Some([input]), ty.clone());
let ext_op = make_int_op(&op, 6);

let hugr = test_int_op_with_results::<1>(ext_op, 6, Some([input]), ty.clone());
assert_eq!(exec_ctx.exec_hugr_u64(hugr, "main"), result);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
---
source: hugr-llvm/src/extension/int.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

@0 = private unnamed_addr constant [25 x i8] c"Can't narrow into bounds\00", align 1

define { i1, { i32, i8* }, i64 } @_hl.main.1(i64 %0) {
alloca_block:
%"0" = alloca { i1, { i32, i8* }, i64 }, align 8
%"2_0" = alloca i64, align 8
%"4_0" = alloca { i1, { i32, i8* }, i64 }, align 8
br label %entry_block

entry_block: ; preds = %alloca_block
store i64 %0, i64* %"2_0", align 4
%"2_01" = load i64, i64* %"2_0", align 4
%bounds_check = icmp ugt i64 %"2_01", -1
%1 = insertvalue { i1, { i32, i8* }, i64 } { i1 true, { i32, i8* } poison, i64 poison }, i64 %"2_01", 2
%2 = select i1 %bounds_check, { i1, { i32, i8* }, i64 } { i1 false, { i32, i8* } { i32 2, i8* getelementptr inbounds ([25 x i8], [25 x i8]* @0, i32 0, i32 0) }, i64 poison }, { i1, { i32, i8* }, i64 } %1
store { i1, { i32, i8* }, i64 } %2, { i1, { i32, i8* }, i64 }* %"4_0", align 8
%"4_02" = load { i1, { i32, i8* }, i64 }, { i1, { i32, i8* }, i64 }* %"4_0", align 8
store { i1, { i32, i8* }, i64 } %"4_02", { i1, { i32, i8* }, i64 }* %"0", align 8
%"03" = load { i1, { i32, i8* }, i64 }, { i1, { i32, i8* }, i64 }* %"0", align 8
ret { i1, { i32, i8* }, i64 } %"03"
}
15 changes: 15 additions & 0 deletions hugr-llvm/src/extension/snapshots/[email protected]
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
---
source: hugr-llvm/src/extension/int.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

define i64 @_hl.main.1(i8 %0) {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
%1 = sext i8 %0 to i64
ret i64 %1
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
---
source: hugr-llvm/src/extension/int.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

define i64 @_hl.main.1(i8 %0) {
alloca_block:
%"0" = alloca i64, align 8
%"2_0" = alloca i8, align 1
%"4_0" = alloca i64, align 8
br label %entry_block

entry_block: ; preds = %alloca_block
store i8 %0, i8* %"2_0", align 1
%"2_01" = load i8, i8* %"2_0", align 1
%1 = sext i8 %"2_01" to i64
store i64 %1, i64* %"4_0", align 4
%"4_02" = load i64, i64* %"4_0", align 4
store i64 %"4_02", i64* %"0", align 4
%"03" = load i64, i64* %"0", align 4
ret i64 %"03"
}
14 changes: 14 additions & 0 deletions hugr-llvm/src/extension/snapshots/[email protected]
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
---
source: hugr-llvm/src/extension/int.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

define i8 @_hl.main.1(i8 %0) {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
ret i8 %0
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
---
source: hugr-llvm/src/extension/int.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

define i8 @_hl.main.1(i8 %0) {
alloca_block:
%"0" = alloca i8, align 1
%"2_0" = alloca i8, align 1
%"4_0" = alloca i8, align 1
br label %entry_block

entry_block: ; preds = %alloca_block
store i8 %0, i8* %"2_0", align 1
%"2_01" = load i8, i8* %"2_0", align 1
store i8 %"2_01", i8* %"4_0", align 1
%"4_02" = load i8, i8* %"4_0", align 1
store i8 %"4_02", i8* %"0", align 1
%"03" = load i8, i8* %"0", align 1
ret i8 %"03"
}
Loading