Skip to content

Commit

Permalink
[PyCDE][NFC] Create signals from Python objects through Types (#4611)
Browse files Browse the repository at this point in the history
Instead of calling `_obj_to_value` from the support library, use some
OOP to leverage the `Type` system to do it. The `_obj_to_value` function
dates back to when we were using raw `ir.Types` and not extending them.
Now that we have a proper type hierarchy (and can rely on it being
used), this change is possible.
  • Loading branch information
teqdruid authored Feb 2, 2023
1 parent 1e8f4fa commit acc95b2
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 109 deletions.
8 changes: 4 additions & 4 deletions frontends/PyCDE/src/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from typing import List, Optional, Set, Tuple, Dict

from .common import (AppID, Clock, Input, Output, PortError, _PyProxy)
from .support import (get_user_loc, _obj_to_attribute, _obj_to_value,
create_type_string, create_const_zero)
from .support import (get_user_loc, _obj_to_attribute, create_type_string,
create_const_zero)
from .signals import ClockSignal, Signal, _FromCirctValue
from .types import ClockType

Expand Down Expand Up @@ -140,7 +140,7 @@ def _set_output(self, idx, signal):
raise PortError(
f"Input port {pname} expected type {ptype}, not {signal.type}")
else:
signal = _obj_to_value(signal, ptype)
signal = ptype(signal)
self._output_values[idx] = signal

def _set_outputs(self, signal_dict: Dict[str, Signal]):
Expand Down Expand Up @@ -433,7 +433,7 @@ def instantiate(self, module_inst, instance_name: str, **inputs):
else:
# If it's not a signal, assume the user wants to specify a constant and
# try to convert it to a hardware constant.
signal = _obj_to_value(signal, ptype)
signal = ptype(signal)
input_values[idx] = signal
del input_lookup[name]

Expand Down
3 changes: 1 addition & 2 deletions frontends/PyCDE/src/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,9 +668,8 @@ def reg(self, clk, rst=None, name=None):

def unwrap(self, ready):
from .dialects import esi
from .support import _obj_to_value
from .types import types
ready = _obj_to_value(ready, types.i1)
ready = types.i1(ready)
unwrap_op = esi.UnwrapValidReadyOp(self.type.inner_type, types.i1,
self.value, ready.value)
return unwrap_op[0], unwrap_op[1]
Expand Down
91 changes: 4 additions & 87 deletions frontends/PyCDE/src/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,92 +63,9 @@ def create_const_zero(type):
return hw.BitcastOp(type, zero)


class OpOperandConnect(support.OpOperand):
"""An OpOperand pycde extension which adds a connect method."""

def connect(self, obj, result_type=None):
if result_type is None:
result_type = self.type
val = _obj_to_value(obj, self.type, result_type)
support.connect(self, val)


def _obj_to_value(x, type: "Type", result_type=None) -> ir.Value:
"""Convert a python object to a CIRCT value, given the CIRCT type."""
if x is None:
raise ValueError(
"Encountered 'None' when trying to build hardware for python value.")
from .signals import Signal
from .dialects import hw, hwarith
from .types import (Type, TypeAlias, Array, StructType, BitVectorType, Bits,
UInt, SInt, _FromCirctType)

assert isinstance(type, Type)
if isinstance(x, Signal):
if x.type != type:
raise TypeError(f"expected {x.type}, got {type}")
return x

if isinstance(type, TypeAlias):
return _obj_to_value(x, type.inner_type, type)

if result_type is None:
result_type = type
else:
result_type = _FromCirctType(result_type)
assert isinstance(result_type, TypeAlias) or result_type == type

val = support.get_value(x)
# If x is already a valid value, just return it.
if val is not None:
if val.type != result_type:
raise ValueError(f"Expected {result_type}, got {val.type}")
return val

if isinstance(x, int):
if not isinstance(type, BitVectorType):
raise ValueError(f"Int can only be converted to hw int, not '{type}'")
with get_user_loc():
if isinstance(type, Bits):
return hw.ConstantOp(type, x)
elif isinstance(type, (UInt, SInt)):
return hwarith.ConstantOp(type, x)
else:
assert False, "Internal error: bit vector type unknown"

if isinstance(x, (list, tuple)):
if not isinstance(type, Array):
raise ValueError(f"List is only convertable to hw array, not '{type}'")
elemty = result_type.element_type
if len(x) != type.size:
raise ValueError("List must have same size as array "
f"{len(x)} vs {type.size}")
list_of_vals = list(map(lambda x: _obj_to_value(x, elemty), x))
# CIRCT's ArrayCreate op takes the array in reverse order.
with get_user_loc():
return hw.ArrayCreateOp(reversed(list_of_vals))

if isinstance(x, dict):
if not isinstance(type, StructType):
raise ValueError(f"Dict is only convertable to hw struct, not '{type}'")
elem_name_values = []
for (fname, ftype) in type.fields:
if fname not in x:
raise ValueError(f"Could not find expected field: {fname}")
v = _obj_to_value(x[fname], ftype)
elem_name_values.append((fname, v))
x.pop(fname)
if len(x) > 0:
raise ValueError(f"Extra fields specified: {x}")
with get_user_loc():
return hw.StructCreateOp(elem_name_values, result_type=result_type._type)

raise ValueError(f"Unable to map object '{x}' to MLIR Value")


def _infer_type(x):
"""Infer the CIRCT type from a python object. Only works on lists."""
from .types import Array, _FromCirctType
from .types import Array
from .signals import Signal
if isinstance(x, Signal):
return x.type
Expand All @@ -169,10 +86,10 @@ def _infer_type(x):
def _obj_to_value_infer_type(value) -> ir.Value:
"""Infer the CIRCT type, then convert the Python object to a CIRCT Value of
that type."""
type = _infer_type(value)
if type is None:
cde_type = _infer_type(value)
if cde_type is None:
raise ValueError(f"Cannot infer CIRCT type from '{value}")
return _obj_to_value(value, type)
return cde_type(value)


def create_type_string(ty):
Expand Down
126 changes: 110 additions & 16 deletions frontends/PyCDE/src/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from collections import OrderedDict

from .support import _obj_to_value
from .support import get_user_loc

from .circt import ir, support
from .circt.dialects import esi, hw, sv
Expand All @@ -19,12 +19,12 @@ def __init__(self):
self.registered_aliases = OrderedDict()

def __getattr__(self, name: str) -> ir.Type:
return self.wrap(ir.Type.parse(name))
return self.wrap(_FromCirctType(ir.Type.parse(name)))

def int(self, width: int, name: str = None):
return self.wrap(Bits(width), name)

def array(self, inner: ir.Type, size: int, name: str = None) -> hw.ArrayType:
def array(self, inner: ir.Type, size: int, name: str = None) -> "Array":
return self.wrap(Array(inner, size), name)

def inout(self, inner: ir.Type):
Expand All @@ -33,11 +33,8 @@ def inout(self, inner: ir.Type):
def channel(self, inner):
return self.wrap(Channel(inner))

def struct(self, members, name: str = None) -> hw.StructType:
s = StructType(members)
if name is None:
return s
return TypeAlias(s, name)
def struct(self, members, name: str = None) -> "StructType":
return self.wrap(StructType(members), name)

@property
def any(self):
Expand All @@ -46,7 +43,7 @@ def any(self):
def wrap(self, type, name=None):
if name is not None:
type = TypeAlias(type, name)
return _FromCirctType(type)
return type


types = _Types()
Expand Down Expand Up @@ -85,14 +82,38 @@ def strip(self):
def bitwidth(self):
return hw.get_bitwidth(self._type)

def __call__(self, obj, name: str = None):
def __call__(self, obj, name: str = None) -> "Signal":
"""Create a Value of this type from a python object."""
assert not isinstance(obj, ir.Value)
v = _obj_to_value(obj, self, self)
v = self._from_obj_or_sig(obj)
if name is not None:
v.name = name
return v

def _from_obj_or_sig(self,
obj,
alias: typing.Optional["TypeAlias"] = None) -> "Signal":
"""Implement the object-signal conversion wherein 'obj' can be a Signal. If
'obj' is already a Signal, check its type and return it. Can be overriden by
subclasses, though calls _from_obj() to do the type-specific const
conversion so we recommend subclasses override that method."""

from .signals import Signal
if isinstance(obj, Signal):
if obj.type != self:
raise TypeError(f"Expected signal of type {self} but got {obj.type}")
return obj
return self._from_obj(obj, alias)

def _from_obj(self,
obj,
alias: typing.Optional["TypeAlias"] = None) -> "Signal":
"""Do the type-specific object validity checks and return a Signal from the
object. Can assume the 'obj' is NOT a Signal. Any subclass which wants to be
created MUST override this method."""

assert False, "Subclass must override this method"

def _get_value_class(self):
"""Return the class which should be instantiated to create a Value."""
from .signals import UntypedSignal
Expand Down Expand Up @@ -232,6 +253,9 @@ def _get_value_class(self):
def wrap(self, value):
return self(value)

def _from_obj(self, obj, alias: typing.Optional["TypeAlias"] = None):
return self.inner_type._from_obj_or_sig(obj, alias=self)


class Array(Type):

Expand Down Expand Up @@ -273,6 +297,20 @@ def __repr__(self) -> str:
def __str__(self) -> str:
return f"{self.element_type}[{self.size}]"

def _from_obj(self, obj, alias: typing.Optional[TypeAlias] = None):
from .dialects import hw
if not isinstance(obj, (list, tuple)):
raise ValueError(
f"Arrays can only be created from lists or tuples, not '{type(obj)}'")
if len(obj) != self.size:
raise ValueError("List must have same size as array "
f"{len(obj)} vs {self.size}")
elemty = self.element_type
list_of_vals = list(map(lambda x: elemty._from_obj_or_sig(x), obj))
with get_user_loc():
# CIRCT's ArrayCreate op takes the array in reverse order.
return hw.ArrayCreateOp(reversed(list_of_vals))


class StructType(Type):

Expand All @@ -299,6 +337,25 @@ def _get_value_class(self):
from .signals import StructSignal
return StructSignal

def _from_obj(self, x, alias: typing.Optional[TypeAlias] = None):
from .dialects import hw
if not isinstance(x, dict):
raise ValueError(
f"Structs can only be created from dicts, not '{type(x)}'")
elem_name_values = []
for (fname, ftype) in self.fields:
if fname not in x:
raise ValueError(f"Could not find expected field: {fname}")
v = ftype._from_obj_or_sig(x[fname])
elem_name_values.append((fname, v))
x.pop(fname)
if len(x) > 0:
raise ValueError(f"Extra fields specified: {x}")

result_type = self if alias is None else alias
with get_user_loc():
return hw.StructCreateOp(elem_name_values, result_type=result_type._type)

def __repr__(self) -> str:
ret = "struct { "
first = True
Expand All @@ -307,7 +364,7 @@ def __repr__(self) -> str:
first = False
else:
ret += ", "
ret += f"{field[0]}: {_FromCirctType(field[1])}"
ret += f"{field[0]}: {field[1]}"
ret += "}"
return ret

Expand All @@ -324,7 +381,7 @@ def __new__(cls, fields: typing.List[typing.Tuple[str, Type]], name: str,
return inst

def __call__(self, **kwargs):
return _obj_to_value(kwargs, self)
return self._from_obj_or_sig(kwargs)

def _get_value_class(self):
return self._value_class
Expand All @@ -336,6 +393,15 @@ class BitVectorType(Type):
def width(self):
return self._type.width

def _from_obj_check(self, x):
"""This functionality can be shared by all the int types."""
if not isinstance(x, int):
raise ValueError(f"{type(self).__name__} can only be created from ints, "
f"not {type(x).__name__}")
signed_bit = 1 if isinstance(self, SInt) else 0
if x.bit_length() + signed_bit > self.width:
raise ValueError(f"{x} overflows type {self}")

def __repr__(self) -> str:
return f"{type(self).__name__}<{self.width}>"

Expand All @@ -352,6 +418,12 @@ def _get_value_class(self):
from .signals import BitsSignal
return BitsSignal

def _from_obj(self, x: int, alias: typing.Optional[TypeAlias] = None):
from .dialects import hw
self._from_obj_check(x)
circt_type = self if alias is None else alias
return hw.ConstantOp(circt_type, x)


class SInt(BitVectorType):

Expand All @@ -365,6 +437,12 @@ def _get_value_class(self):
from .signals import SIntSignal
return SIntSignal

def _from_obj(self, x: int, alias: typing.Optional[TypeAlias] = None):
from .dialects import hwarith
self._from_obj_check(x)
circt_type = self if alias is None else alias
return hwarith.ConstantOp(circt_type, x)


class UInt(BitVectorType):

Expand All @@ -378,6 +456,14 @@ def _get_value_class(self):
from .signals import UIntSignal
return UIntSignal

def _from_obj(self, x: int, alias: typing.Optional[TypeAlias] = None):
from .dialects import hwarith
self._from_obj_check(x)
if x < 0:
raise ValueError(f"UInt can only store positive numbers, not {x}")
circt_type = self if alias is None else alias
return hwarith.ConstantOp(circt_type, x)


class ClockType(Bits):
"""A special single bit to represent a clock. Can't do any special operations
Expand Down Expand Up @@ -426,10 +512,18 @@ def __repr__(self):
def inner(self):
return self.inner_type

def wrap(self, value, valid):
def wrap(self, value, valid) -> typing.Tuple["ChannelSignal", "BitsSignal"]:
"""Wrap a data signal and valid signal into a data channel signal and a
ready signal."""

# Instead of implementing __call__(), we require users to call this method
# instead. In addition to being clearer, the type signature isn't the same
# -- this returns a tuple of Signals (data, ready) -- rather than a single
# one.

from .dialects import esi
value = _obj_to_value(value, self.inner_type)
valid = _obj_to_value(valid, types.i1)
value = self.inner_type(value)
valid = types.i1(valid)
wrap_op = esi.WrapValidReadyOp(self._type, types.i1, value.value,
valid.value)
return wrap_op[0], wrap_op[1]
Expand Down
Loading

0 comments on commit acc95b2

Please sign in to comment.