Skip to content

Add Selene to test compiled hugr #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions .github/workflows/pull-request.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
name: Pull Request

on:
push:
branches:
- main
pull_request:
branches:
- '**'
merge_group:
types: [checks_requested]

env:
SCCACHE_GHA_ENABLED: "true"
UV_VERSION: '0.7.19'

jobs:
check:
name: Check Python (${{ matrix.python-version }})
runs-on: ubuntu-latest
env:
PYTHON_VERSION: ${{ matrix.python-version }}

strategy:
matrix:
python-version: [ '3.10', '3.12' ]

steps:
- uses: actions/checkout@v4
- name: Run sccache-cache
uses: mozilla-actions/[email protected]

- name: Install UV
uses: astral-sh/setup-uv@v5
with:
version: ${{ env.UV_VERSION }}
enable-cache: true

- name: Install qcorrect
run: uv sync --locked --all-extras --dev

- name: Check formatting with ruff
run: uv run ruff format --check src

- name: Lint with ruff
run: uv run ruff check src

- name: Run tests
run: uv run pytest
Empty file added examples/__init__.py
Empty file.
63 changes: 32 additions & 31 deletions examples/example.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,56 @@
from typing import Generic
from collections.abc import Callable
from typing import Generic, no_type_check

from guppylang import guppy
from guppylang.decorator import guppy
from guppylang.std import quantum as phys
from guppylang.std.builtins import array, comptime, nat, owned

import qcorrect as qct

# Define a new codeblock type with parameter `N`
# Define logical code block
N = guppy.nat_var("N")

class ExampleCode:

@qct.type(copyable=False, droppable=False)
class CodeBlock(Generic[N]):
data_qs: array[phys.qubit, N]
@qct.type(copyable=False, droppable=False)
class CodeBlock(Generic[N]):
data_qs: array[phys.qubit, N]

# Define logical operations
@qct.operation
def zero(n: nat @ comptime) -> "CodeBlock[n]":
return CodeBlock(array(phys.qubit() for _ in range(n)))

@qct.operation
def measure(q: CodeBlock[N] @ owned) -> array[bool, N]:
return phys.measure_array(q.data_qs)
# Define code operations
class CodeDef(qct.CodeDefinition):
def __init__(self, n: nat):
self.n: nat = n

@qct.operation
def zero(self) -> Callable:
@guppy
@no_type_check
def circuit() -> "CodeBlock[comptime(self.n)]":
return CodeBlock(array(phys.qubit() for _ in range(comptime(self.n))))

return circuit

# Define a new codeblock type with parameter `N`
N = guppy.nat_var("N")
@qct.operation
def measure(self) -> Callable:
@guppy
@no_type_check
def circuit(
q: "CodeBlock[comptime(self.n)] @ owned",
) -> "array[bool, comptime(self.n)]":
return phys.measure_array(q.data_qs)

@qct.type(copyable=False, droppable=False)
class CodeBlock(Generic[N]):
data_qs: array[phys.qubit, N]
return circuit

class ExampleCode:

# Define logical operations
@qct.operation
def zero(n: nat @ comptime) -> "CodeBlock[n]":
return CodeBlock(array(phys.qubit() for _ in range(n)))
# Create code instance and get guppy module
code = CodeDef(5).get_module()

@qct.operation
def measure(q: CodeBlock[N] @ owned) -> array[bool, N]:
return phys.measure_array(q.data_qs)
# Define a code instance
code = ExampleCode()

# Use new code to create a hugr module
# Write logical guppy program
@guppy
def main() -> None:
q = code.zero(6)
q = code.zero()
code.measure(q)


hugr = main.compile()
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,6 @@ lint = [
]
test = [
"ipykernel>=6.29.5",
]
"pytest>=8.4.1",
"selene-hugr-qis-compiler>=0.2.0rc1",
]
9 changes: 4 additions & 5 deletions ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,10 @@ ignore = [
]

[lint.per-file-ignores]
"guppy/ast_util.py" = ["B009", "B010"]
"guppy/decorator.py" = ["B010"]
"tests/integration/*" = ["F841", "C416", "RUF005"]
"tests/{hugr,integration}/*" = ["B", "FBT", "SIM", "I"]
"__init__.py" = ["F401"] # module imported but unused
"src/qcorrect/decorator.py" = ["B023"]
"tests/*" = ["F841", "C416", "RUF005"]
"examples/*" = ["F821"]
"__init__.py" = ["F401"]

# [pydocstyle]
# convention = "google"
Expand Down
2 changes: 1 addition & 1 deletion src/qcorrect/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from qcorrect.decorator import operation, type
from qcorrect.decorator import CodeDefinition, operation, type
94 changes: 70 additions & 24 deletions src/qcorrect/decorator.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
import inspect
from types import ModuleType

from guppylang.checker.core import Globals
from guppylang.checker.func_checker import check_signature
from guppylang.decorator import custom_guppy_decorator, get_calling_frame, guppy
from guppylang.definition.common import DefId
from guppylang.definition.function import parse_py_func
from guppylang.definition.struct import RawStructDef
from guppylang.definition.value import CallableDef
from guppylang.engine import DEF_STORE, ENGINE
from guppylang.std._internal.util import quantum_op
from guppylang.tys.subst import Inst
from hugr import ops
from hugr import tys as ht
from hugr.ext import ExplicitBound, Extension, OpDef, OpDefSig, TypeDef
from pydantic_extra_types.semantic_version import SemanticVersion

# This is required to get parsed function definitions
ENGINE.reset()

# Currently used to store hugr extensions for types
# Should be moved to the `code` decorator
hugr_ext = Extension("qcorrect", SemanticVersion(0, 1, 0))


@custom_guppy_decorator
def type(copyable: bool = True, droppable: bool = True):
"""Decorator to define code types"""
frame = get_calling_frame()

def wrapper(cls):
Expand All @@ -27,7 +33,7 @@ def wrapper(cls):

type_def = TypeDef(
name=cls.__name__,
description=cls.__doc__,
description=cls.__doc__ or "",
params=[],
bound=ExplicitBound(ht.TypeBound.Any),
)
Expand All @@ -48,23 +54,63 @@ def wrapper(cls):

@custom_guppy_decorator
def operation(defn):
guppy_dec = guppy.declare(defn)
func_ast, _ = parse_py_func(
DEF_STORE.raw_defs[guppy_dec.id].python_func, DEF_STORE.sources
)
ty = check_signature(func_ast, Globals(DEF_STORE.frames[guppy_dec.id]))

op_def = OpDef(
name=defn.__name__,
description="",
signature=OpDefSig(poly_func=ty.to_hugr_poly()),
lower_funcs=[],
)

hugr_ext.add_op_def(op_def)

def empty_dec() -> None: ...

return guppy.hugr_op(
quantum_op(defn.__name__, ext=hugr_ext), name=defn.__name__, signature=ty
)(empty_dec)
"""Decorator to define code operations"""

defn.__setattr__("_qct_op", True)

return defn


class CodeDefinition:
guppy_module: ModuleType
hugr_ext: Extension

@custom_guppy_decorator
def get_module(self) -> ModuleType:
self.guppy_module = ModuleType(self.__class__.__name__)
self.hugr_ext = Extension(self.__class__.__name__, SemanticVersion(0, 1, 0))
self.inner_defs = {}

# Get all `inner` operations
for name, defn in inspect.getmembers(self, predicate=inspect.ismethod):
if hasattr(defn, "_qct_op"):
self.inner_defs[name] = defn()

# Define `outer` operations
guppy_def = self.inner_defs[name]
parsed_def = ENGINE.get_parsed(guppy_def.id)

assert isinstance(parsed_def, CallableDef)

ty = parsed_def.ty

op_def = OpDef(
name=name,
description=defn.__doc__ or "",
signature=OpDefSig(poly_func=ty.to_hugr_poly()),
lower_funcs=[
# FixedHugr(
# extensions=ht.ExtensionSet(),
# hugr=compiled_def.package.to_str(),
# )
],
)

self.hugr_ext.add_op_def(op_def)

def empty_dec() -> None: ...

def hugr_op(op_def):
def op(ty: ht.FunctionType, inst: Inst) -> ops.DataflowOp:
return ops.ExtOp(op_def, ty)

return op

guppy_op = guppy.hugr_op(
hugr_op(op_def),
name=name,
signature=ty,
)(empty_dec)
self.guppy_module.__setattr__(name, guppy_op)

return self.guppy_module
Empty file added tests/__init__.py
Empty file.
79 changes: 79 additions & 0 deletions tests/test_codedef.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from collections.abc import Callable
from typing import Generic

import pytest
from guppylang.decorator import get_calling_frame, guppy
from guppylang.error import GuppyError, GuppyTypeError
from guppylang.std import quantum as phys
from guppylang.std.builtins import array, comptime, nat, owned
from selene_hugr_qis_compiler import check_hugr

import qcorrect as qct

N = guppy.nat_var("N")


@qct.type(copyable=False, droppable=False)
class CodeBlock(Generic[N]):
data_qs: array[phys.qubit, N]


class CodeDef(qct.CodeDefinition):
def __init__(self, n: nat):
self.n: nat = n
self.frame = get_calling_frame()

@qct.operation
def zero(self) -> Callable:
@guppy
def circuit() -> "CodeBlock[comptime(self.n)]":
return CodeBlock(array(phys.qubit() for _ in range(comptime(self.n))))

return circuit

@qct.operation
def measure(self) -> Callable:
@guppy
def circuit(
q: "CodeBlock[comptime(self.n)] @ owned",
) -> "array[bool, comptime(self.n)]":
return phys.measure_array(q.data_qs)

return circuit


def test_code_usage():
code = CodeDef(5).get_module()

@guppy
def main() -> None:
q = code.zero()
code.measure(q)

hugr = main.compile()

check_hugr(hugr.package.to_bytes())


def test_mismatched_codes():
code4 = CodeDef(4).get_module()
code5 = CodeDef(5).get_module()

@guppy
def main() -> None:
q = code4.zero()
code5.measure(q)

with pytest.raises(GuppyTypeError):
main.compile()


def test_block_dropped():
code = CodeDef(5).get_module()

@guppy
def main() -> None:
q = code.zero()

with pytest.raises(GuppyError):
main.compile()
Loading