Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion frontends/PyCDE/src/pycde/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .circt import ir

from contextvars import ContextVar
from functools import singledispatchmethod
from functools import cached_property, singledispatchmethod
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
import re
import numpy as np
Expand Down Expand Up @@ -708,6 +708,27 @@ def build(self):
# All the work is done in the metaclass.


class UnionSignal(Signal):

@cached_property
def field_indices(self) -> Dict[str, int]:
return {
name: idx for idx, (name, _, _) in enumerate(self.type.strip.fields)
}

def __getitem__(self, sub):
if sub not in self.field_indices:
raise LookupError(f"Union field '{sub}' not found in {self.type}")
from .dialects import hw
with get_user_loc():
return hw.UnionExtractOp(self.value, self.field_indices[sub])

def __getattr__(self, attr):
if attr not in self.field_indices:
raise AttributeError(f"{type(self)} object has no attribute '{attr}'")
return self.__getitem__(attr)


class ChannelSignal(Signal):

def reg(self, clk, rst=None, name=None):
Expand Down
95 changes: 95 additions & 0 deletions frontends/PyCDE/src/pycde/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ def _FromCirctType(type: typing.Union[ir.Type, Type]) -> Type:
return Type.__new__(Array, type)
if isinstance(type, hw.StructType):
return Type.__new__(StructType, type)
if isinstance(type, hw.UnionType):
return Type.__new__(UnionType, type)
if isinstance(type, hw.TypeAliasType):
return Type.__new__(TypeAlias, type, incl_cls_in_key=False)
if isinstance(type, hw.InOutType):
Expand Down Expand Up @@ -409,6 +411,99 @@ def is_hw_type(self) -> bool:
return True


class UnionType(Type):

def __new__(
cls, fields: typing.Union[typing.List[typing.Tuple[str, Type]],
typing.List[typing.Tuple[str, Type, int]],
typing.Dict[str, Type]]
) -> UnionType:
if len(fields) == 0:
raise ValueError("Unions must have at least one field.")
if isinstance(fields, dict):
fields = list(fields.items())
if not isinstance(fields, list):
raise TypeError("Expected either list or dict.")

circt_fields = []
for field in fields:
if len(field) == 2:
circt_fields.append((field[0], field[1]._type, 0))
elif len(field) == 3:
circt_fields.append((field[0], field[1]._type, field[2]))
else:
raise ValueError(
"Fields must be either (name, type) or (name, type, offset)")

return super(UnionType, cls).__new__(cls, hw.UnionType.get(circt_fields))

@property
def is_hw_type(self) -> bool:
return True

@property
def fields(self):
return [(n, _FromCirctType(t), o) for n, t, o in self._type.get_fields()]

def __getattr__(self, attrname: str):
for field in self.fields:
if field[0] == attrname:
return _FromCirctType(self._type.get_field(attrname))
return super().__getattribute__(attrname)

def _get_value_class(self):
from .signals import UnionSignal
return UnionSignal

def __repr__(self) -> str:
ret = "union { "
first = True
for field in self.fields:
if first:
first = False
else:
ret += ", "
ret += f"{field[0]}: {field[1]}"
if field[2] > 0:
ret += f" offset {field[2]}"
ret += "}"
return ret

def _from_obj(self, x, alias: typing.Optional[TypeAlias] = None):
from .dialects import hw
if not isinstance(x, tuple):
raise ValueError(
f"Unions can only be created from tuples, not '{type(x)}'")
if len(x) != 2:
raise ValueError(
"Union tuple must have exactly 2 elements: (name, value)")

name, value = x
if not isinstance(name, str):
raise TypeError("Union field name must be a string")

# Find the field in the union type
field_type = None
field_index = -1
for idx, (fname, ftype, _) in enumerate(self.fields):
if fname == name:
field_type = ftype
field_index = idx
break

if field_type is None:
raise ValueError(f"Field '{name}' not found in union type {self}")

# Convert the value to a signal
val_sig = field_type._from_obj_or_sig(value)

result_type = self if alias is None else alias
with get_user_loc():
return hw.UnionCreateOp(result_type._type,
fieldIndex=field_index,
input=val_sig.value)


class BitVectorType(Type):

@property
Expand Down
37 changes: 36 additions & 1 deletion frontends/PyCDE/test/test_pycde_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# RUN: %PYTHON% %s | FileCheck %s

from pycde import dim, Input, Output, generator, System, Module
from pycde.types import Bit, Bits, List, StructType, TypeAlias, UInt, SInt
from pycde.types import Bit, Bits, List, StructType, TypeAlias, UInt, SInt, UnionType
from pycde.testing import unittestmodule
from pycde.signals import Struct, UIntSignal

Expand Down Expand Up @@ -86,3 +86,38 @@ def build(self):
s = ExStruct(a=self.inp1.a, b=self.inp1.get_b_plus1().as_uint(32))
assert type(s) is ExStruct._get_value_class()
self.out2 = s


# CHECK: union { a: Bits<32>, b: Bits<16>}
u = UnionType([("a", Bits(32)), ("b", Bits(16))])
print(u)
# CHECK: [('a', Bits<32>, 0), ('b', Bits<16>, 0)]
print(u.fields)

# CHECK: union { a: Bits<32>, b: Bits<16> offset 32}
u2 = UnionType([("a", Bits(32)), ("b", Bits(16), 32)])
print(u2)
# CHECK: [('a', Bits<32>, 0), ('b', Bits<16>, 32)]
print(u2.fields)


# CHECK-LABEL: hw.module @TestUnion(in %in1 : !hw.union<a: i32, b: i16>, out outA : i32, out outB : i16, out out1 : !hw.union<a: i32, b: i16>, out out2 : !hw.union<a: i32, b: i16>) attributes {output_file = #hw.output_file<"TestUnion.sv", includeReplicatedOps>} {
# CHECK-NEXT: [[R0:%.+]] = hw.union_extract %in1["a"] : !hw.union<a: i32, b: i16>
# CHECK-NEXT: [[R1:%.+]] = hw.union_extract %in1["b"] : !hw.union<a: i32, b: i16>
# CHECK-NEXT: %c456_i16 = hw.constant 456 : i16
# CHECK-NEXT: [[R2:%.+]] = hw.union_create "b", %c456_i16 : !hw.union<a: i32, b: i16>
# CHECK-NEXT: hw.output [[R0]], [[R1]], %in1, [[R2]] : i32, i16, !hw.union<a: i32, b: i16>, !hw.union<a: i32, b: i16>
@unittestmodule()
class TestUnion(Module):
in1 = Input(u)
outA = Output(Bits(32))
outB = Output(Bits(16))
out1 = Output(u)
out2 = Output(u)

@generator
def build(ports):
ports.out1 = ports.in1
ports.outA = ports.in1["a"]
ports.outB = ports.in1.b
ports.out2 = u(("b", 456))
4 changes: 4 additions & 0 deletions lib/Bindings/Python/HWModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ void circt::python::populateDialectHWSubmodule(nb::module_ &m) {
llvm::SmallVector<llvm::SmallString<8>> names;
for (size_t i = 0, e = pyFieldInfos.size(); i < e; ++i) {
auto tuple = nb::cast<nb::tuple>(pyFieldInfos[i]);
if (tuple.size() < 3)
throw std::invalid_argument(
"UnionType field info must be a tuple of (name, type, "
"offset)");
auto type = nb::cast<MlirType>(tuple[1]);
size_t offset = nb::cast<size_t>(tuple[2]);
ctx = mlirTypeGetContext(type);
Expand Down
2 changes: 2 additions & 0 deletions lib/Bindings/Python/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def type_to_pytype(t) -> ir.Type:
return hw.ArrayType(t)
if hw.StructType.isinstance(t):
return hw.StructType(t)
if hw.UnionType.isinstance(t):
return hw.UnionType(t)
if hw.TypeAliasType.isinstance(t):
return hw.TypeAliasType(t)
if hw.InOutType.isinstance(t):
Expand Down