diff --git a/frontends/PyCDE/integration_test/esi_ram.py b/frontends/PyCDE/integration_test/esi_ram.py index abbcd45e783d..33a3624b5755 100644 --- a/frontends/PyCDE/integration_test/esi_ram.py +++ b/frontends/PyCDE/integration_test/esi_ram.py @@ -6,7 +6,7 @@ # PY: run_cosim(tmpdir, rpcschemapath, simhostport) import pycde -from pycde import (Clock, Input, module, generator, types) +from pycde import (Clock, Input, Module, generator, types) from pycde.constructs import Wire from pycde import esi @@ -24,8 +24,7 @@ class MemComms: to_client_type=WriteType) -@module -class Mid: +class Mid(Module): clk = Clock(types.i1) rst = Input(types.i1) @@ -40,8 +39,7 @@ def construct(ports): RamI64x8.write(write_data) -@module -class top: +class top(Module): clk = Clock(types.i1) rst = Input(types.i1) diff --git a/frontends/PyCDE/integration_test/esi_test.py b/frontends/PyCDE/integration_test/esi_test.py index a28a62f2308e..b9fe89b5e67a 100644 --- a/frontends/PyCDE/integration_test/esi_test.py +++ b/frontends/PyCDE/integration_test/esi_test.py @@ -6,11 +6,10 @@ # PY: run_cosim(tmpdir, rpcschemapath, simhostport) import pycde -from pycde import (Clock, Input, InputChannel, OutputChannel, module, generator, +from pycde import (Clock, Input, InputChannel, OutputChannel, Module, generator, types) from pycde import esi from pycde.constructs import Wire -from pycde.dialects import comb import sys @@ -23,8 +22,7 @@ class HostComms: to_client_type=types.i32) -@module -class Producer: +class Producer(Module): clk = Input(types.i1) int_out = OutputChannel(types.i32) @@ -34,8 +32,7 @@ def construct(ports): ports.int_out = chan -@module -class Consumer: +class Consumer(Module): clk = Input(types.i1) int_in = InputChannel(types.i32) @@ -44,8 +41,7 @@ def construct(ports): HostComms.to_host(ports.int_in, "loopback_out") -@module -class LoopbackInOutAdd7: +class LoopbackInOutAdd7(Module): @generator def construct(ports): @@ -59,8 +55,7 @@ def construct(ports): loopback.assign(data_chan) -@module -class Mid: +class Mid(Module): clk = Clock(types.i1) rst = Input(types.i1) @@ -72,8 +67,7 @@ def construct(ports): LoopbackInOutAdd7() -@module -class Top: +class Top(Module): clk = Clock(types.i1) rst = Input(types.i1) diff --git a/frontends/PyCDE/integration_test/pytorch/dot_prod_system.py b/frontends/PyCDE/integration_test/pytorch/dot_prod_system.py index a920f67f40fb..eb2b1c659955 100644 --- a/frontends/PyCDE/integration_test/pytorch/dot_prod_system.py +++ b/frontends/PyCDE/integration_test/pytorch/dot_prod_system.py @@ -7,7 +7,7 @@ # PY: from dot_prod_system import run_cosim # PY: run_cosim(tmpdir, rpcschemapath, simhostport) -from pycde import Input, module, generator, types +from pycde import Input, Module, generator, types from pycde.common import Clock from pycde.system import System from pycde.esi import FromServer, ToFromServer, ServiceDecl, CosimBSP @@ -41,8 +41,7 @@ # output_type="linalg-on-tensors") -@module -class Gasket: +class Gasket(Module): """Wrap the accelerator IP module. Instantiate the requiste memories. Wire the memories to the host and the host to the module control signals.""" diff --git a/frontends/PyCDE/src/__init__.py b/frontends/PyCDE/src/__init__.py index 705431d2b304..fed5ff45b211 100644 --- a/frontends/PyCDE/src/__init__.py +++ b/frontends/PyCDE/src/__init__.py @@ -2,12 +2,6 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from .common import (AppID, Clock, Input, InputChannel, Output, OutputChannel) -from .module import (externmodule, generator, module, no_connect) -from .system import (System) -from .pycde_types import (dim, types) -from .value import (Value) - from .circt import ir from . import circt import atexit @@ -24,6 +18,12 @@ def __exit_ctxt(): DefaultContext.__exit__(None, None, None) +from .common import (AppID, Clock, Input, InputChannel, Output, OutputChannel) +from .module import (generator, modparams, Module) +from .system import (System) +from .pycde_types import (dim, types) +from .value import (Value) + # Until we get source location based on Python stack traces, default to unknown # locations. DefaultLocation = ir.Location.unknown() diff --git a/frontends/PyCDE/src/common.py b/frontends/PyCDE/src/common.py index a2f28e5c00d6..0569d7e32c84 100644 --- a/frontends/PyCDE/src/common.py +++ b/frontends/PyCDE/src/common.py @@ -88,3 +88,7 @@ class _PyProxy: def __init__(self, name: str): self.name = name + + +class PortError(Exception): + pass diff --git a/frontends/PyCDE/src/constructs.py b/frontends/PyCDE/src/constructs.py index b33048364503..1caacfbd4d9a 100644 --- a/frontends/PyCDE/src/constructs.py +++ b/frontends/PyCDE/src/constructs.py @@ -8,7 +8,7 @@ from .pycde_types import PyCDEType, dim, types from .value import BitsSignal, BitVectorSignal, ListValue, Value, Signal from .value import get_slice_bounds -from .module import generator, module, _BlockContext +from .module import generator, modparams, Module, _BlockContext from .circt.support import get_value, BackedgeBuilder from .circt.dialects import msft, hw, sv from pycde.dialects import comb @@ -149,10 +149,10 @@ def ControlReg(clk: Signal, rst: Signal, asserts: List[Signal], opposite. If both an assert and a reset are active on the same cycle, the assert takes priority.""" - @module + @modparams def ControlReg(num_asserts: int, num_resets: int): - class ControlReg: + class ControlReg(Module): clk = Clock() rst = Input(types.i1) out = Output(types.i1) diff --git a/frontends/PyCDE/src/esi.py b/frontends/PyCDE/src/esi.py index 2f4eaef500d3..6c90a555c609 100644 --- a/frontends/PyCDE/src/esi.py +++ b/frontends/PyCDE/src/esi.py @@ -3,10 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from pycde.system import System -from .module import (Generator, _module_base, _BlockContext, - _GeneratorPortAccess, _SpecializedModule) +from .module import Generator, _BlockContext, Module, ModuleLikeBuilderBase from pycde.value import ChannelValue, ClockSignal, Signal, Value -from .common import AppID, Input, Output, InputChannel, OutputChannel, _PyProxy +from .common import Input, Output, InputChannel, OutputChannel, _PyProxy from .circt.dialects import esi as raw_esi, hw, msft from .circt.support import BackedgeBuilder from pycde.pycde_types import ChannelType, ClockType, PyCDEType, types @@ -15,7 +14,7 @@ from pathlib import Path import shutil -from typing import List, Optional, Type +from typing import Dict, List, Optional, Type __dir__ = Path(__file__).parent @@ -180,11 +179,10 @@ def Cosim(decl: ServiceDecl, clk, rst): def CosimBSP(user_module): """Wrap and return a cosimulation 'board support package' containing 'user_module'""" - from .module import module, generator + from .module import Module, generator from .common import Clock, Input - @module - class top: + class top(Module): clk = Clock() rst = Input(types.int(1)) @@ -256,8 +254,7 @@ class _ServiceGeneratorChannels: """Provide access to the channels which the service generator is responsible for connecting up.""" - def __init__(self, mod: _SpecializedModule, - req: raw_esi.ServiceImplementReqOp): + def __init__(self, mod: Module, req: raw_esi.ServiceImplementReqOp): self._req = req portReqsBlock = req.portReqs.blocks[0] @@ -269,7 +266,7 @@ def __init__(self, mod: _SpecializedModule, ] # Find the output channel requests and store the settable proxies. - num_output_ports = len(mod.output_port_lookup) + num_output_ports = len(mod.outputs) self._output_reqs = [ _OutputChannelSetter(req, self._req.results[num_output_ports + idx]) for idx, req in enumerate(portReqsBlock) @@ -294,75 +291,76 @@ def check_unconnected_outputs(self): raise ValueError(f"{name_str} has not been connected.") -def ServiceImplementation(decl: Optional[ServiceDecl]): +class ServiceImplementationModuleBuilder(ModuleLikeBuilderBase): + """Define how to build ESI service implementations. Unlike Modules, there is + no distinction between definition and instance -- ESI service providers are + built where they are instantiated.""" + + def instantiate(self, impl, instance_name: str, **inputs): + # Each instantiation of the ServiceImplementation has its own + # registration. + opts = _service_generator_registry.register(impl) + + # Create the op. + decl_sym = None + if impl.decl is not None: + decl_sym = ir.FlatSymbolRefAttr.get(impl.decl._materialize_service_decl()) + return raw_esi.ServiceInstanceOp( + result=[t for _, t in self.outputs], + service_symbol=decl_sym, + impl_type=_ServiceGeneratorRegistry._impl_type_name, + inputs=[inputs[pn].value for pn, _ in self.inputs], + impl_opts=opts, + loc=self.loc) + + def generate_svc_impl(self, serviceReq: raw_esi.ServiceInstanceOp): + """"Generate the service inline and replace the `ServiceInstanceOp` which is + being implemented.""" + + assert len(self.generators) == 1 + generator: Generator = list(self.generators.values())[0] + ports = self.generator_port_proxy(serviceReq.operation.operands, self) + with self.GeneratorCtxt(self, ports, serviceReq, generator.loc): + + # Run the generator. + channels = _ServiceGeneratorChannels(self, serviceReq) + rc = generator.gen_func(ports, channels=channels) + if rc is None: + rc = True + elif not isinstance(rc, bool): + raise ValueError("Generators must a return a bool or None") + ports._check_unconnected_outputs() + channels.check_unconnected_outputs() + + # Replace the output values from the service implement request op with + # the generated values. Erase the service implement request op. + for idx, port_value in enumerate(ports._output_values): + msft.replaceAllUsesWith(serviceReq.operation.results[idx], + port_value.value) + serviceReq.operation.erase() + + return rc + + +class ServiceImplementation(Module): """A generator for a service implementation. Must contain a @generator method which will be called whenever required to implement the server. Said generator function will be called with the same 'ports' argument as modules and a 'channels' argument containing lists of the input and output channels which need to be connected to the service being implemented.""" - def wrap(service_impl, decl: Optional[ServiceDecl] = decl): - - def instantiate_cb(mod: _SpecializedModule, instance_name: str, - inputs: dict, appid: AppID, loc): - # Each instantiation of the ServiceImplementation has its own - # registration. - opts = _service_generator_registry.register(mod) - decl_sym = None - if decl is not None: - decl_sym = ir.FlatSymbolRefAttr.get(decl._materialize_service_decl()) - return raw_esi.ServiceInstanceOp( - result=[t for _, t in mod.output_ports], - service_symbol=decl_sym, - impl_type=_ServiceGeneratorRegistry._impl_type_name, - inputs=[inputs[pn].value for pn, _ in mod.input_ports], - impl_opts=opts, - loc=loc) - - def generate(generator: Generator, spec_mod: _SpecializedModule, - serviceReq: raw_esi.ServiceInstanceOp): - arguments = serviceReq.operation.operands - with ir.InsertionPoint( - serviceReq), generator.loc, BackedgeBuilder(), _BlockContext(): - # Insert generated code after the service instance op. - ports = _GeneratorPortAccess(spec_mod, arguments) - - # Enter clock block implicitly if only one clock given - clk = None - if len(spec_mod.clock_ports) == 1: - clk_port = list(spec_mod.clock_ports.values())[0] - clk = ClockSignal(arguments[clk_port], ClockType()) - clk.__enter__() - - # Run the generator. - channels = _ServiceGeneratorChannels(spec_mod, serviceReq) - rc = generator.gen_func(ports, channels=channels) - if rc is None: - rc = True - elif not isinstance(rc, bool): - raise ValueError("Generators must a return a bool or None") - ports.check_unconnected_outputs() - channels.check_unconnected_outputs() - - # Replace the output values from the service implement request op with - # the generated values. Erase the service implement request op. - for port_name, port_value in ports._output_values.items(): - port_num = spec_mod.output_port_lookup[port_name] - msft.replaceAllUsesWith(serviceReq.operation.results[port_num], - port_value.value) - serviceReq.operation.erase() - - if clk is not None: - clk.__exit__(None, None, None) - - return rc - - return _module_base(service_impl, - extern_name=None, - generator_cb=generate, - instantiate_cb=instantiate_cb) - - return wrap + BuilderType = ServiceImplementationModuleBuilder + + def __init__(self, decl: Optional[ServiceDecl], **inputs): + """Instantiate a service provider for service declaration 'decl'. If decl, + implementation is expected to handle any and all service declarations.""" + + self.decl = decl + super().__init__(**inputs) + + @property + def name(self): + return self.__class__.__name__ class _ServiceGeneratorRegistry: @@ -372,7 +370,7 @@ class _ServiceGeneratorRegistry: _impl_type_name = ir.StringAttr.get("pycde") def __init__(self): - self._registry = {} + self._registry: Dict[str, ServiceImplementation] = {} # Register myself with ESI so I can dispatch to my internal registry. assert _ServiceGeneratorRegistry._registered is False, \ @@ -382,7 +380,8 @@ def __init__(self): self._implement_service) _ServiceGeneratorRegistry._registered = True - def register(self, service_implementation: _SpecializedModule) -> ir.DictAttr: + def register(self, + service_implementation: ServiceImplementation) -> ir.DictAttr: """Register a ServiceImplementation generator with the PyCDE generator. Called when the ServiceImplamentation is defined.""" @@ -407,7 +406,7 @@ def _implement_service(self, req: ir.Operation): return False (impl, sys) = self._registry[impl_name] with sys: - return impl.generate(serviceReq=req.opview) + return impl._builder.generate_svc_impl(serviceReq=req.opview) _service_generator_registry = _ServiceGeneratorRegistry() diff --git a/frontends/PyCDE/src/fsm.py b/frontends/PyCDE/src/fsm.py index 20e5810db283..29f0b470647a 100644 --- a/frontends/PyCDE/src/fsm.py +++ b/frontends/PyCDE/src/fsm.py @@ -1,5 +1,5 @@ -from pycde import Input, Output, module, generator -from pycde.module import _SpecializedModule, Generator, _GeneratorPortAccess +from pycde import Input, Output, generator +from pycde.module import Generator, Module from pycde.dialects import fsm from pycde.pycde_types import types from typing import Callable @@ -71,7 +71,7 @@ def States(n): return [State() for _ in range(n)] -def create_fsm_machine_op(sys, mod: _SpecializedModule, symbol): +def create_fsm_machine_op(sys, mod: Module, symbol): """Creation callback for creating a FSM MachineOp.""" # Add attributes for in- and output names. @@ -94,8 +94,7 @@ def create_fsm_machine_op(sys, mod: _SpecializedModule, symbol): ip=sys._get_ip()) -def generate_fsm_machine_op(generate_obj: Generator, - spec_mod: _SpecializedModule): +def generate_fsm_machine_op(generate_obj: Generator, spec_mod): """ Generator callback for generating an FSM op. """ entry_block = spec_mod.circt_mod.body.blocks[0] ports = _GeneratorPortAccess(spec_mod, entry_block.arguments) diff --git a/frontends/PyCDE/src/instance.py b/frontends/PyCDE/src/instance.py index 632a8e496bfc..667da817fc8f 100644 --- a/frontends/PyCDE/src/instance.py +++ b/frontends/PyCDE/src/instance.py @@ -15,11 +15,11 @@ class Instance: """Represents a _specific_ instance, unique in a design. This is in contrast to a module instantiation within another module.""" - from .module import _SpecializedModule + from .module import Module __slots__ = ["parent", "inside_of", "root", "symbol", "_op_cache"] - def __init__(self, parent: Instance, inside_of: _SpecializedModule, + def __init__(self, parent: Instance, inside_of: Module, symbol: Optional[ir.Attribute]): """ Construct a new instance. Since the terminology can be confusing: @@ -106,10 +106,10 @@ class ModuleInstance(Instance): """Instance specialization for modules. Since they are the only thing which can contain operations (for now), put all of the children stuff in here.""" - from .module import _SpecializedModule + from .module import Module def __init__(self, parent: Instance, instance_sym: Optional[ir.Attribute], - inside_of: _SpecializedModule, tgt_mod: _SpecializedModule): + inside_of: Module, tgt_mod: Module): super().__init__(parent, inside_of, instance_sym) self.tgt_mod = tgt_mod self._child_cache: Dict[ir.StringAttr, Instance] = None @@ -258,11 +258,11 @@ def __iter__(self) -> Iterator[Instance]: class RegInstance(Instance): """Instance specialization for registers.""" - from .module import _SpecializedModule + from .module import Module __slots__ = ["type"] - def __init__(self, parent: Instance, inside_of: _SpecializedModule, + def __init__(self, parent: Instance, inside_of: Module, symbol: Optional[ir.Attribute], static_op: seq.CompRegOp): super().__init__(parent, inside_of, symbol) @@ -286,10 +286,9 @@ class InstanceHierarchyRoot(ModuleInstance): 'top' module). Plus, CIRCT models it this way. """ import pycde.system as cdesys - from .module import _SpecializedModule + from .module import Module - def __init__(self, module: _SpecializedModule, instance_name: str, - sys: cdesys.System): + def __init__(self, module: Module, instance_name: str, sys: cdesys.System): self.instance_name = instance_name self.system = sys self.root = self diff --git a/frontends/PyCDE/src/module.py b/frontends/PyCDE/src/module.py index ee432bac1ac9..a4e55c33a4c4 100644 --- a/frontends/PyCDE/src/module.py +++ b/frontends/PyCDE/src/module.py @@ -3,12 +3,12 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from __future__ import annotations -from typing import Optional, Tuple, Union, Dict +from typing import List, Optional, Tuple, Union, Dict from pycde.pycde_types import ClockType from pycde.support import _obj_to_value -from .common import (AppID, Clock, Input, Output, _PyProxy) +from .common import (AppID, Clock, Input, Output, PortError, _PyProxy) from .support import (get_user_loc, _obj_to_attribute, OpOperandConnect, create_type_string, create_const_zero) from .value import ClockSignal, Signal, Value @@ -23,10 +23,10 @@ import sys # A memoization table for module parameterization function calls. -_MODULE_CACHE: Dict[Tuple[builtins.function, mlir.ir.DictAttr], object] = {} +_MODULE_CACHE: Dict[Tuple[builtins.function, ir.DictAttr], object] = {} -def _create_module_name(name: str, params: mlir.ir.DictAttr): +def _create_module_name(name: str, params: ir.DictAttr): """Create a "reasonable" module name from a base name and a set of parameters. E.g. PolyComputeForCoeff_62_42_6.""" @@ -53,279 +53,450 @@ def val_str(val): return ret.strip("_") -def create_msft_module_op(sys, mod: _SpecializedModule, symbol): - """Creation callback for creating a MSFTModuleOp.""" - return msft.MSFTModuleOp(symbol, - mod.input_ports, - mod.output_ports, - mod.parameters, - loc=mod.loc, - ip=sys._get_ip()) +def _get_module_cache_key(func, + params) -> Tuple[builtins.function, ir.DictAttr]: + """The "module" cache is specifically for parameterized modules. It maps the + module parameterization function AND parameter values to the class which was + generated by a previous call to said module parameterization function.""" + if not isinstance(params, ir.DictAttr): + params = _obj_to_attribute(params) + return (func, params) + + +_current_block_context = ContextVar("current_block_context") + + +class _BlockContext: + """Bookkeeping for a generator scope.""" + + def __init__(self): + self.symbols: set[str] = set() + + @staticmethod + def current() -> _BlockContext: + """Get the top-most context in the stack created by `with + _BlockContext()`.""" + bb = _current_block_context.get(None) + assert bb is not None + return bb + def __enter__(self): + self._old_system_token = _current_block_context.set(self) -def generate_msft_module_op(generator: Generator, spec_mod: _SpecializedModule, - **kwargs): + def __exit__(self, exc_type, exc_value, traceback): + if exc_value is not None: + return + _current_block_context.reset(self._old_system_token) + + def uniquify_symbol(self, sym: str) -> str: + """Create a unique symbol and add it to the cache. If it is to be preserved, + the caller must use it as the symbol on a top-level op.""" + ctr = 0 + ret = sym + while ret in self.symbols: + ctr += 1 + ret = f"{sym}_{ctr}" + self.symbols.add(ret) + return ret - def create_output_op(args: _GeneratorPortAccess): - """Create the hw.OutputOp from module I/O ports in 'args'.""" - output_ports = spec_mod.output_ports - outputs: list[Value] = list() +class Generator: + """ + Represents a generator. Stores the generate function and location of + the generate call. Generator objects are passed to module-specific generator + object handlers. + """ + + def __init__(self, gen_func): + self.gen_func = gen_func + self.loc = get_user_loc() + + +def generator(func): + """Decorator for generation functions.""" + return Generator(func) + + +class PortProxyBase: + """Extensions of this class provide access to module ports in generators. + Subclasses essentially just provide syntactic sugar around the methods in this + base class. None of the methods here are intended to be directly used by the + PyCDE developer.""" + + def __init__(self, block_args, builder): + self._block_args = block_args + self._output_values = [None] * len(builder.outputs) + self._builder = builder + + def _get_input(self, idx): + val = self._block_args[idx] + if idx in self._builder.clocks: + return ClockSignal(val, ClockType()) + return Value(val) + + def _set_output(self, idx, signal): + assert signal is not None + pname, ptype = self._builder.outputs[idx] + if isinstance(signal, Signal): + if ptype != signal.type: + raise PortError( + f"Input port {pname} expected type {ptype}, not {signal.type}") + else: + signal = _obj_to_value(signal, ptype) + self._output_values[idx] = signal + + def _check_unconnected_outputs(self): unconnected_ports = [] - for (name, _) in output_ports: - if name not in args._output_values: - unconnected_ports.append(name) - outputs.append(None) - else: - outputs.append(args._output_values[name]) + for idx, value in enumerate(self._output_values): + if value is None: + unconnected_ports.append(self._builder.outputs[idx][0]) if len(unconnected_ports) > 0: - raise support.UnconnectedSignalError(spec_mod.name, unconnected_ports) - - msft.OutputOp([o.value for o in outputs]) - - bc = _BlockContext() - entry_block = spec_mod.circt_mod.add_entry_block() - with ir.InsertionPoint(entry_block), generator.loc, BackedgeBuilder(), bc: - args = _GeneratorPortAccess(spec_mod, entry_block.arguments) - - # Enter clock block implicitly if only one clock given - clk = None - if len(spec_mod.clock_ports) == 1: - clk_port = list(spec_mod.clock_ports.values())[0] - val = entry_block.arguments[clk_port] - clk = ClockSignal(val, ClockType()) - clk.__enter__() - - outputs = generator.gen_func(args, **kwargs) - if outputs is not None: - raise ValueError("Generators must not return a value") - if create_output_op is not None: - create_output_op(args) - - if clk is not None: - clk.__exit__(None, None, None) - - -def create_msft_module_extern_op(sys, mod: _SpecializedModule, symbol): - """Creation callback for creating a MSFTModuleExternOp.""" - paramdecl_list = [ - hw.ParamDeclAttr.get_nodefault(i.name, i.attr.type) - for i in mod.parameters - ] - return msft.MSFTModuleExternOp( - symbol, - mod.input_ports, - mod.output_ports, - parameters=paramdecl_list, - attributes={"verilogName": ir.StringAttr.get(mod.extern_name)}, - loc=mod.loc, - ip=sys._get_ip()) - - -class _SpecializedModule(_PyProxy): - """SpecializedModule serves two purposes: - - (1) As a level of indirection between pure python and python CIRCT op - classes. This indirection makes it possible to invalidate the reference and - clean up when those ops may not exist anymore. - - (2) It delays module op creation until there is a valid context and system to - create it in. As a result of how the delayed creation works, module ops are - only created if said module is instantiated.""" - - __slots__ = [ - "generators", - "modcls", - "loc", - "clock_ports", - "input_ports", - "input_port_lookup", - "output_ports", - "output_port_lookup", - "parameters", - "extern_name", - "create_cb", - "generator_cb", - "instantiate_cb", - ] - - def __init__(self, - orig_cls: type, - cls: type, - parameters: Union[dict, ir.DictAttr], - extern_name: str, - create_cb: Optional[builtins.function], - generator_cb: Optional[builtins.function] = None, - instantiate_cb: Optional[builtins.function] = None): + raise support.UnconnectedSignalError(self.name, unconnected_ports) + + def _clear(self): + """TL;DR: Downgrade a shotgun to a handgun. + + Instances are not _supposed_ to be held on beyond the end of generators... + but at least one user will try. This method clears the contents of this + class to prevent users from encountering a totally bizzare, unrelated error + message when they make this mistake. If users reach into this class and hold + on to private references... well, we did what we could to prevent their foot + damage.""" + + self._block_args = None + self._output_values = None + self._builder = None + + +class ModuleLikeBuilderBase(_PyProxy): + """`ModuleLikeBuilder`s are responsible for preparing `Module` and other + module-like subclasses for use. They are responsible for scanning the + subclass' attribute, recognizing certain types (e.g. `InputPort`), and taking + actions/mutating the subclass based on that information. They are also + responsible for creating CIRCT IR -- creating the initial op, generating the + bodies, instantiating modules, etc. + + This is the base class for common functionality which all 'ModuleLike` classes + are likely to need. Each `ModuleLike` type will need to subclass this base. + For instance, plain 'ol `Module`s have a corresponding subclass called + `ModuleBuilder`. The correspondence is given by the `BuilderType` class + variable in `Module`.""" + + def __init__(self, cls, cls_dct, loc): self.modcls = cls - self.extern_name = extern_name - self.loc = get_user_loc() - self.create_cb = create_cb - self.generator_cb = generator_cb - self.instantiate_cb = instantiate_cb - - # Make sure 'parameters' is a DictAttr rather than a python value. - self.parameters = _obj_to_attribute(parameters) - assert isinstance(self.parameters, ir.DictAttr) - - # Get the module name - if extern_name is not None: - self.name = extern_name - elif "module_name" in dir(cls): - self.name = cls.module_name - elif "get_module_name" in dir(cls): - self.name = cls.get_module_name() - else: - self.name = _create_module_name(cls.__name__, self.parameters) - - # Inputs, Outputs, and parameters are all class members. We must populate - # them. Scan 'cls' for them. - self.input_ports = [] - self.input_port_lookup: Dict[str, int] = {} # Used by 'BlockArgs' below. - self.output_port_lookup: Dict[str, int] = {} # Used by 'BlockArgs' below. - self.output_ports = [] - self.generators = {} - self.clock_ports: Dict[str, int] = {} - for attr_name in vars(orig_cls): + self.cls_dct = cls_dct + self.loc = loc + + self.outputs = None + self.inputs = None + self.clocks = None + self.generators = None + self.generator_port_proxy = None + self.parameters = None + + def go(self): + """Execute the analysis and mutation to make a `ModuleLike` class operate + as such.""" + + self.scan_cls() + self.generator_port_proxy = self.create_port_proxy() + self.add_external_port_accessors() + + def scan_cls(self): + """Scan the class for input/output ports and generators. (Most `ModuleLike` + will use these.) Store the results for later use.""" + + input_ports = [] + output_ports = [] + clock_ports = set() + generators = {} + for attr_name, attr in self.cls_dct.items(): if attr_name.startswith("_"): continue - attr = getattr(cls, attr_name) - if isinstance(attr, Input): - attr.name = attr_name - self.input_ports.append((attr.name, attr.type)) - self.input_port_lookup[attr_name] = len(self.input_ports) - 1 + if isinstance(attr, Clock): + clock_ports.add(len(input_ports)) + input_ports.append((attr_name, ir.IntegerType.get_signless(1))) + elif isinstance(attr, Input): + input_ports.append((attr_name, attr.type)) elif isinstance(attr, Output): - attr.name = attr_name - self.output_ports.append((attr.name, attr.type)) - self.output_port_lookup[attr_name] = len(self.output_ports) - 1 + output_ports.append((attr_name, attr.type)) elif isinstance(attr, Generator): - self.generators[attr_name] = attr + generators[attr_name] = attr - if isinstance(attr, Clock): - self.clock_ports[attr.name] = len(self.input_ports) - 1 - self.add_accessors() - - def add_accessors(self): - """Add accessors for each input and output port to emulate generated OpView - subclasses.""" - for (idx, (name, type)) in enumerate(self.input_ports): - setattr( - self.modcls, name, - property(lambda self, idx=idx: OpOperandConnect( - self._instantiation.operation, idx, self._instantiation.operation. - operands[idx], self))) - for (idx, (name, type)) in enumerate(self.output_ports): - setattr( - self.modcls, name, - property(lambda self, idx=idx, type=type: Value( - self._instantiation.operation.results[idx], type))) - - # Bug: currently only works with one System. See notes at the top of this - # class. - def create(self): - """Create the module op. Should not be called outside of a 'System' - context. Returns the symbol of the module op.""" - - if self.create_cb is None: - return + self.outputs = output_ports + self.inputs = input_ports + self.clocks = clock_ports + self.generators = generators - from .system import System - sys = System.current() - sys._create_circt_mod(self) + def create_port_proxy(self): + """Create a proxy class for generators to use in order to access module + ports. Instances of this will (usually) be used in place of the `self` + argument in generator calls. + + Replaces the dynamic lookup scheme previously utilized. Should be faster and + (more importantly) reduces the amount of bookkeeping necessary.""" + + proxy_attrs = {} + for idx, (name, port_type) in enumerate(self.inputs): + proxy_attrs[name] = property(lambda self, idx=idx: self._get_input(idx)) + + for idx, (name, port_type) in enumerate(self.outputs): + + def fget(self, idx=idx): + self._get_output(idx) + + def fset(self, val, idx=idx): + self._set_output(idx, val) + + proxy_attrs[name] = property(fget=fget, fset=fset) + + return type(self.modcls.__name__ + "Ports", (PortProxyBase,), proxy_attrs) + + def add_external_port_accessors(self): + """For each port, replace it with a property to provide access to the + instances output in OTHER generators which are instantiating this module.""" + + for idx, (name, port_type) in enumerate(self.inputs): + + def fget(self): + raise PortError("Cannot access signal via instance input") + + setattr(self.modcls, name, property(fget=fget)) + + for idx, (name, port_type) in enumerate(self.outputs): + + def fget(self, idx=idx): + return Value(self.inst.results[idx]) + + setattr(self.modcls, name, property(fget=fget)) @property - def is_created(self): - return self.circt_mod is not None + def name(self): + if hasattr(self.modcls, "module_name"): + return self.modcls.module_name + elif self.parameters is not None: + return _create_module_name(self.modcls.__name__, self.parameters) + else: + return self.modcls.__name__ + + def print(self, out): + print( + f"", + file=out) + + class GeneratorCtxt: + """Provides an context which most genertors need.""" + + def __init__(self, builder: ModuleLikeBuilderBase, ports: PortProxyBase, ip, + loc: ir.Location) -> None: + self.bc = _BlockContext() + self.bb = BackedgeBuilder() + self.ip = ir.InsertionPoint(ip) + self.loc = loc + self.clk = None + self.ports = ports + if len(builder.clocks) == 1: + # Enter clock block implicitly if only one clock given. + clk_port = list(builder.clocks)[0] + self.clk = ClockSignal(ports._block_args[clk_port], ClockType()) + + def __enter__(self): + self.bc.__enter__() + self.bb.__enter__() + self.ip.__enter__() + self.loc.__enter__() + if self.clk is not None: + self.clk.__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + if self.clk is not None: + self.clk.__exit__(exc_type, exc_value, traceback) + self.loc.__exit__(exc_type, exc_value, traceback) + self.ip.__exit__(exc_type, exc_value, traceback) + self.bb.__exit__(exc_type, exc_value, traceback) + self.bc.__exit__(exc_type, exc_value, traceback) + self.ports._clear() + + +class ModuleLikeType(type): + """ModuleLikeType is a metaclass for Module and other things which look like + modules (e.g. ServiceImplementations). A metaclass is nice since it gets run + on each class (including subclasses), so has the ability to modify it. This is + in contrast to a decorator which gets run once. It also has the advantage of + being able to pretty easily extend `Module` or create an entirely new + `ModuleLike` hierarchy. + + This metaclass essentially just kicks the brunt of the work over to a + specified `ModuleLikeBuilder`, which can -- unlike metaclasses -- have state. + Presumably, the usual thing is to store all of this state in the class itself, + but we need this state to be private. Given that this isn't possible in + Python, a single '_' variable is as small a surface area as we can get.""" + + def __init__(cls, name, bases, dct: Dict): + super(ModuleLikeType, cls).__init__(name, bases, dct) + cls._builder = cls.BuilderType(cls, dct, get_user_loc()) + cls._builder.go() + + +class ModuleBuilder(ModuleLikeBuilderBase): + """Defines how a `Module` gets built. Extend the base class and customize.""" @property def circt_mod(self): + """Get the raw CIRCT operation for the module definition. DO NOT store the + returned value!!! It needs to get reaped after the current action (e.g. + instantiation, generation). Memory safety when interacting with native code + can be painful.""" + from .system import System - sys = System.current() - return sys._op_cache.get_circt_mod(self) - - def instantiate(self, instance_name: str, inputs: dict, appid: AppID, loc): - """Create a instance op.""" - if self.instantiate_cb is not None: - ret = self.instantiate_cb(self, instance_name, inputs, appid, loc) - elif self.extern_name is None: - ret = self.circt_mod.instantiate(instance_name, **inputs, loc=loc) - else: - ret = self.circt_mod.instantiate(instance_name, - **inputs, - parameters=self.parameters, - loc=loc) - if appid is not None: - ret.operation.attributes[AppID.AttributeName] = appid._appid + sys: System = System.current() + ret = sys._op_cache.get_circt_mod(self) + if ret is None: + return sys._create_circt_mod(self) return ret - def generate(self, **kwargs): + def create_op(self, sys, symbol): + """Callback for creating a module op.""" + + if len(self.generators) > 0: + # If this Module has a generator, it's a real module. + return msft.MSFTModuleOp( + symbol, + self.inputs, + self.outputs, + self.parameters if hasattr(self, "parameters") else None, + loc=self.loc, + ip=sys._get_ip()) + + # Modules without generators are implicitly considered to be external. + if self.parameters is None: + paramdecl_list = [] + else: + paramdecl_list = [ + hw.ParamDeclAttr.get_nodefault(i.name, i.attr.type) + for i in self.parameters + ] + return msft.MSFTModuleExternOp( + symbol, + self.inputs, + self.outputs, + parameters=paramdecl_list, + attributes={"verilogName": ir.StringAttr.get(self.name)}, + loc=self.loc, + ip=sys._get_ip()) + + def instantiate(self, module_inst, instance_name: str, **inputs): + """"Instantiate this Module. Check that the input types match expectations.""" + + from .circt.dialects import _hw_ops_ext as hwext + input_lookup = { + name: (idx, ptype) for idx, (name, ptype) in enumerate(self.inputs) + } + input_values: List[Optional[Signal]] = [None] * len(self.inputs) + + for name, signal in inputs.items(): + if name not in input_lookup: + raise PortError(f"Input port {name} not found in module") + idx, ptype = input_lookup[name] + if isinstance(signal, Signal): + # If the input is a signal, the types must match. + if ptype != signal.type: + raise PortError( + f"Input port {name} expected type {ptype}, not {signal.type}") + else: + # If it's not a signal, assume the user wants to specify a constant and + # try to convert it to a hardware constant. + signal = _obj_to_value(signal, ptype) + input_values[idx] = signal + del input_lookup[name] + + if len(input_lookup) > 0: + missing = ", ".join(list(input_lookup.keys())) + raise ValueError(f"Missing input signals for ports: {missing}") + + circt_mod = self.circt_mod + parameters = None + # If this is a parameterized external module, the parameters must be + # supplied. + if len(self.generators) == 0 and self.parameters is not None: + parameters = ir.ArrayAttr.get( + hwext.create_parameters(self.parameters, circt_mod)) + inst = msft.InstanceOp(circt_mod.type.results, + instance_name, + ir.FlatSymbolRefAttr.get( + ir.StringAttr( + circt_mod.attributes["sym_name"]).value), + [sig.value for sig in input_values], + parameters=parameters, + loc=get_user_loc()) + inst.verify() + return inst + + def generate(self): """Fill in (generate) this module. Only supports a single generator currently.""" assert len(self.generators) == 1 - for g in self.generators.values(): - return self.generator_cb(g, self, **kwargs) + g: Generator = list(self.generators.values())[0] + + entry_block = self.circt_mod.add_entry_block() + ports = self.generator_port_proxy(entry_block.arguments, self) + with self.GeneratorCtxt(self, ports, entry_block, g.loc): + outputs = g.gen_func(ports) + if outputs is not None: + raise ValueError("Generators must not return a value") + + ports._check_unconnected_outputs() + msft.OutputOp([o.value for o in ports._output_values]) + + +class Module(metaclass=ModuleLikeType): + """Subclass this class to define a regular PyCDE or external module. To define + a module in PyCDE, supply a `@generator` method. To create an external module, + don't. In either case, a list of ports is required. + + A few important notes: + - If your subclass overrides the constructor, it MUST call the parent + constructor AND pass through all of the input port signals to said parent + constructor. Using kwargs (e.g. **inputs) is the easiest way to fulfill this + requirement. + - If you have a @generator, you MUST NOT hold on to, store, or otherwise leak + the first argument (i.e. self) beyond the function return. It is a special + instance constructed exclusively for the generator. + """ - def print(self, out): - print(f"", - file=out) - - -# Set an input to no_connect to indicate not to connect it. Only valid for -# external module inputs. -no_connect = object() - - -def module(func_or_class): - """Decorator to signal that a class should be treated as a module or a - function should be treated as a module parameterization function. In the - latter case, the function must return a python class to be treated as the - parameterized module.""" - generate_cb = func_or_class.generator_cb if hasattr( - func_or_class, "generator_cb") else generate_msft_module_op - create_cb = func_or_class.create_cb if hasattr( - func_or_class, "create_cb") else create_msft_module_op - if inspect.isclass(func_or_class): - # If it's just a module class, we should wrap it immediately - return _module_base(func_or_class, - None, - generator_cb=generate_cb, - create_cb=create_cb) - elif inspect.isfunction(func_or_class): - return _parameterized_module(func_or_class, - None, - generator_cb=generate_cb, - create_cb=create_cb) - raise TypeError( - "@module decorator must be on class or parameterization function") + BuilderType = ModuleBuilder + def __init__(self, instance_name: str = None, appid: AppID = None, **inputs): + """Create an instance of this module. Instance namd and appid are optional. + All inputs must be specified. If a signal has not been produced yet, use the + `Wire` construct and assign the signal to that wire later on.""" -def _get_module_cache_key(func, - params) -> Tuple[builtins.function, ir.DictAttr]: - """The "module" cache is specifically for parameterized modules. It maps the - module parameterization function AND parameter values to the class which was - generated by a previous call to said module parameterization function.""" - if not isinstance(params, ir.DictAttr): - params = _obj_to_attribute(params) - return (func, ir.Attribute(params)) + if instance_name is None: + instance_name = self.__class__.__name__ + instance_name = _BlockContext.current().uniquify_symbol(instance_name) + self.inst = self._builder.instantiate(self, instance_name, **inputs) + if appid is not None: + self.inst.operation.attributes[AppID.AttributeName] = appid._appid + @classmethod + def print(cls, out=sys.stdout): + cls._builder.print(out) -class _parameterized_module: - """When the @module decorator detects that it is decorating a function, use - this class to wrap it.""" - mod = None +class modparams: + """Decorate a function to indicate that it is returning a Module which is + parameterized by this function. Arguments to this class MUST be convertible to + a recognizable constant. Ideally, they would be simple since (by default) they + will be turned into strings and appended to the module name in the resulting + RTL. Arguments with underscore prefixes are ignored and thus exempt from the + previous requirement.""" + func = None - extern_mod = None # When the decorator is attached, this runs. - def __init__(self, - func: builtins.function, - extern_name, - create_cb: builtins.function, - generator_cb: builtins.function = None): - self.extern_name = extern_name + def __init__(self, func: builtins.function): # If it's a module parameterization function, inspect the arguments to # ensure sanity. @@ -336,8 +507,6 @@ def __init__(self, raise TypeError("Module parameter definitions cannot have **kwargs") if param.kind == param.VAR_POSITIONAL: raise TypeError("Module parameter definitions cannot have *args") - self.create_cb = create_cb - self.generator_cb = generator_cb # This function gets executed in two situations: # - In the case of a module function parameterizer, it is called when the @@ -362,296 +531,61 @@ def __call__(self, *args, **kwargs): return _MODULE_CACHE[cache_key] cls = self.func(*args, **kwargs) - if cls is None: - raise ValueError("Parameterization function must return module class") - - mod = _module_base(cls, - self.extern_name, - params=params, - create_cb=self.create_cb, - generator_cb=self.generator_cb) - _MODULE_CACHE[cache_key] = mod - return mod - - -def externmodule(to_be_wrapped, extern_name=None): - """Wrap an externally implemented module. If no name given in the decorator - argument, use the class name.""" - if isinstance(to_be_wrapped, str): - return lambda cls, extern_name=to_be_wrapped: externmodule(cls, extern_name) - - if extern_name is None: - extern_name = to_be_wrapped.__name__ - if inspect.isclass(to_be_wrapped): - # If it's just a module class, we should wrap it immediately - return _module_base(to_be_wrapped, - extern_name, - create_cb=create_msft_module_extern_op) - return _parameterized_module(to_be_wrapped, - extern_name, - create_cb=create_msft_module_extern_op) + if not issubclass(cls, Module): + raise ValueError("Parameterization function must return Module class") + cls._builder.parameters = cache_key[1] + _MODULE_CACHE[cache_key] = cls + return cls -def import_hw_module(hw_module: hw.HWModuleOp): - # Get the module name to use in the generated class and as the external name. - name = ir.StringAttr(hw_module.name).value - - # Collect input and output ports as named Inputs and Outputs. - ports = {} - for input_name, block_arg in hw_module.inputs().items(): - ports[input_name] = Input(block_arg.type, input_name) - for output_name, output_type in hw_module.outputs().items(): - ports[output_name] = Output(output_type, output_name) - # Use the name and ports to construct a class object like what externmodule - # would wrap. - cls = type(name, (object,), ports) +class ImportedModSpec(ModuleBuilder): + """Specialization to support imported CIRCT modules.""" # Creation callback that just moves the already build module into the System's # ModuleOp and returns it. - def create_cb(sys: "System", mod: _SpecializedModule, symbol: str): + def create_op(self, sys, symbol: str): + hw_module = self.modcls.hw_module + # TODO: deal with symbolrefs to this (potentially renamed) module symbol. sys.mod.body.append(hw_module) # Need to clear out the reference to ourselves so that we can release the # raw reference to `hw_module`. It's safe to do so since unlike true PyCDE # modules, this can only be run once during the import_mlir. - mod.create_cb = None + self.modcls.hw_module = None return hw_module - # Hand off the class, external name, and create callback to _module_base. - return _module_base(cls, name, create_cb) - - -def _module_base(cls, - extern_name: str, - create_cb: Optional[builtins.function] = None, - generator_cb: Optional[builtins.function] = None, - instantiate_cb: Optional[builtins.function] = None, - params={}): - """Wrap a class, making it a PyCDE module.""" - - class mod(cls): - - def __init__(self, - *args, - appid: AppID = None, - partition: DesignPartition = None, - **kwargs): - """Scan the class and eventually instance for Input/Output members and - treat the inputs as operands and outputs as results.""" - # Ensure the module has been created. - mod._pycde_mod.create() - # Get the user callsite. - loc = get_user_loc() - - inputs = { - name: kwargs[name] - for (name, _) in mod._pycde_mod.input_ports - if name in kwargs - } - pass_up_kwargs = {n: v for (n, v) in kwargs.items() if n not in inputs} - if len(pass_up_kwargs) > 0: - init_sig = inspect.signature(cls.__init__) - if not any( - [x == inspect.Parameter.VAR_KEYWORD for x in init_sig.parameters]): - raise ValueError("Module constructor doesn't have a **kwargs" - " parameter, so the following are likely inputs" - " which don't have a port: " + - ",".join(pass_up_kwargs.keys())) - cls.__init__(self, *args, **pass_up_kwargs) - - # Build a list of operand values for the operation we're gonna create. - self.backedges: dict[str:BackedgeBuilder.Edge] = {} - for (idx, (name, type)) in enumerate(mod._pycde_mod.input_ports): - if name in inputs: - input = inputs[name] - if input == no_connect: - if extern_name is None: - raise ConnectionError( - "`no_connect` is only valid on extern module ports") - else: - value = create_const_zero(type) - else: - value = _obj_to_value(input, type) - else: - backedge = BackedgeBuilder.current().create(type, - name, - mod._pycde_mod.circt_mod, - loc=loc) - self.backedges[idx] = backedge - value = Value(backedge.result) - inputs[name] = value - - instance_name = cls.__name__ - if "instance_name" in dir(self): - instance_name = self.instance_name - instance_name = _BlockContext.current().uniquify_symbol(instance_name) - # TODO: This is a held Operation*. Add a level of indirection. - self._instantiation = mod._pycde_mod.instantiate(instance_name, - inputs, - appid=appid, - loc=loc) - - op = self._instantiation.operation - if partition: - op.attributes["targetDesignPartition"] = partition._tag - - def output_values(self): - return {outname: getattr(self, outname) for (outname, _) in mod.outputs()} - - @staticmethod - def print(out=sys.stdout): - mod._pycde_mod.print(out) - print() - - @staticmethod - def inputs() -> list[(str, ir.Type)]: - """Return the list of input ports.""" - return mod._pycde_mod.input_ports - - @staticmethod - def outputs() -> list[(str, ir.Type)]: - """Return the list of input ports.""" - return mod._pycde_mod.output_ports - - mod.__qualname__ = cls.__qualname__ - mod.__name__ = cls.__name__ - mod.__module__ = cls.__module__ - mod._pycde_mod = _SpecializedModule(cls, - mod, - params, - extern_name, - create_cb=create_cb, - generator_cb=generator_cb, - instantiate_cb=instantiate_cb) - return mod - - -_current_block_context = ContextVar("current_block_context") - - -class _BlockContext: - """Bookkeeping for a scope.""" - - def __init__(self): - self.symbols: set[str] = set() - - @staticmethod - def current() -> _BlockContext: - """Get the top-most context in the stack created by `with - _BlockContext()`.""" - bb = _current_block_context.get(None) - assert bb is not None - return bb + def instantiate(self, module_inst, instance_name: str, **inputs): + inst = self.circt_mod.instantiate( + instance_name, + **inputs, + parameters={} if self.parameters is None else self.parameters, + loc=get_user_loc()) + inst.operation.verify() + return inst.operation - def __enter__(self): - self._old_system_token = _current_block_context.set(self) - - def __exit__(self, exc_type, exc_value, traceback): - if exc_value is not None: - return - _current_block_context.reset(self._old_system_token) - - def uniquify_symbol(self, sym: str) -> str: - """Create a unique symbol and add it to the cache. If it is to be preserved, - the caller must use it as the symbol on a top-level op.""" - ctr = 0 - ret = sym - while ret in self.symbols: - ctr += 1 - ret = sym + "_" + str(ctr) - self.symbols.add(ret) - return ret - - -class Generator: - """ - Represents a generator. Stores the generate function and location of - the generate call. Generator objects are passed to module-specific generator - object handlers. - """ - - def __init__(self, gen_func): - self.gen_func = gen_func - self.loc = get_user_loc() - - -def generator(func): - """Decorator for generation functions.""" - return Generator(func) +def import_hw_module(hw_module: hw.HWModuleOp): + """Import a CIRCT module into PyCDE. Returns a standard Module subclass which + operates just like an external PyCDE module. -class _GeneratorPortAccess: - """Get the input ports.""" + For now, the imported module name MUST NOT conflict with any other modules.""" - __slots__ = ["_mod", "_output_values", "_block_args"] + # Get the module name to use in the generated class and as the external name. + name = ir.StringAttr(hw_module.name).value - def __init__(self, mod: _SpecializedModule, block_args): - self._mod = mod - self._output_values: dict[str, Value] = {} - self._block_args = block_args + # Collect input and output ports as named Inputs and Outputs. + modattrs = {} + for input_name, block_arg in hw_module.inputs().items(): + modattrs[input_name] = Input(block_arg.type, input_name) + for output_name, output_type in hw_module.outputs().items(): + modattrs[output_name] = Output(output_type, output_name) + modattrs["BuilderType"] = ImportedModSpec + modattrs["hw_module"] = hw_module - # Support attribute access to block arguments by name - def __getattr__(self, name): - if name in self._mod.input_port_lookup: - idx = self._mod.input_port_lookup[name] - val = self._block_args[idx] - if name in self._mod.clock_ports: - return ClockSignal(val, ClockType()) - return Value(val) - if name in self._mod.output_port_lookup: - if name not in self._output_values: - raise ValueError("Must set output value before accessing it") - return self._output_values[name] - - raise AttributeError(f"unknown port name '{name}'") - - def __setattr__(self, name: str, value) -> None: - if name in _GeneratorPortAccess.__slots__: - super().__setattr__(name, value) - return + # Use the name and ports to construct a class object like what externmodule + # would wrap. + cls = type(name, (Module,), modattrs) - if name not in self._mod.output_port_lookup: - raise ValueError(f"Cannot find output port '{name}'") - if name in self._output_values: - raise ValueError(f"Cannot set output '{name}' twice") - - from .behavioral import If - if If.current() is not None: - raise ValueError(f"Cannot set output '{name}' inside an if block") - - output_port = self._mod.output_ports[self._mod.output_port_lookup[name]] - output_port_type = output_port[1] - if not isinstance(value, Signal): - value = _obj_to_value(value, output_port_type) - if value.type != output_port_type: - if value.type == output_port_type.strip: - value = output_port_type.wrap(value) - else: - raise ValueError("Types do not match. Output port type: " - f" '{output_port_type}'. Value type: '{value.type}'") - self._output_values[name] = value - - def set_all_ports(self, port_values: dict[str, Value]): - """Set all of the output values in a portname -> value dict.""" - for (name, value) in port_values.items(): - self.__setattr__(name, value) - - def check_unconnected_outputs(self): - ports_unconnected = self._mod.output_port_lookup.keys() - \ - self._output_values.keys() - if len(ports_unconnected) > 0: - raise ValueError("Generator did not connect all output ports: " + - ", ".join(ports_unconnected)) - - -class DesignPartition: - """Construct a design partition "module" for entities to target.""" - - def __init__(self, name: str): - sym_name = ir.StringAttr.get(name) - dp = msft.DesignPartitionOp(sym_name, sym_name) - parent_mod = dp.operation.parent.attributes["sym_name"] - # TODO: Add SymbolRefAttr to mlir.ir - self._tag = ir.Attribute.parse(f"@{parent_mod}::@{sym_name}") + return cls diff --git a/frontends/PyCDE/src/system.py b/frontends/PyCDE/src/system.py index 7e854e0c271f..1916b0f18b5b 100644 --- a/frontends/PyCDE/src/system.py +++ b/frontends/PyCDE/src/system.py @@ -8,7 +8,7 @@ PhysicalRegion) from .common import _PyProxy -from .module import _SpecializedModule +from .module import Module, ModuleLikeType, ModuleLikeBuilderBase from .pycde_types import types from .instance import Instance, InstanceHierarchyRoot @@ -23,7 +23,7 @@ import os import pathlib import sys -from typing import Any, Callable, Dict, List, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union _current_system = ContextVar("current_pycde_system") @@ -54,10 +54,11 @@ class System: """ def __init__(self, - top_modules: Union[list, _SpecializedModule], + top_modules: Union[list, Module], name: str = "PyCDESystem", output_directory: str = None, sw_api_langs: List[str] = None): + from .module import Module self.passed = False self.mod = ir.Module.create() if isinstance(top_modules, Iterable): @@ -69,8 +70,7 @@ def __init__(self, self._generate_queue = [] # _instance_roots indexed by (module, instance_name). - self._instance_roots: dict[(_SpecializedModule, str), - InstanceHierarchyRoot] = {} + self._instance_roots: dict[(Module, str), InstanceHierarchyRoot] = {} self._placedb: PlacementDB = None @@ -88,7 +88,7 @@ def __init__(self, self.hw_output_dir.mkdir(exist_ok=True) with self: - [m._pycde_mod.create() for m in self.top_modules] + [m._builder.circt_mod for m in self.top_modules] def add_packaging_step(self, func: Callable): self.packaging_funcs.append(func) @@ -108,15 +108,6 @@ def _get_ip(self): def set_debug(): ir._GlobalDebug.flag = True - def import_modules(self, modules: list[_SpecializedModule]): - # Call the imported modules' `create` methods to import the IR into the - # PyCDE System ModuleOp. Also add them to the list of top-level modules so - # later emission stages know about them. - with self: - for module in modules: - module._pycde_mod.create() - self.top_modules.append(module) - # TODO: Ideally, we'd be able to run the std-to-handshake lowering passes in # pycde. As of now, however, the cf/memref/arith dialects are not registered # so the assembly can't be loaded. The right way to do this is to have pycde @@ -153,7 +144,7 @@ def import_mlir(self, module, lowering=None): if isinstance(op, hw.HWModuleOp): from .module import import_hw_module im = import_hw_module(op) - self._create_circt_mod(im._pycde_mod) + self._create_circt_mod(im._builder) ret[ir.StringAttr(op.name).value] = im elif isinstance(op, esi.RandomAccessMemoryDeclOp): from .esi import _import_ram_decl @@ -175,27 +166,27 @@ def create_entity_extern(self, tag: str, metadata=""): entity_extern = EntityExtern(tag, metadata) return entity_extern - def _create_circt_mod(self, spec_mod: _SpecializedModule): + def _create_circt_mod(self, builder: ModuleLikeBuilderBase): """Wrapper for a callback (which actually builds the CIRCT op) which controls all the bookkeeping around CIRCT module ops.""" - (symbol, install_func) = self._op_cache.create_symbol(spec_mod) + (symbol, install_func) = self._op_cache.create_symbol(builder) if symbol is None: return # Build the correct op. - op = spec_mod.create_cb(self, spec_mod, symbol) + op = builder.create_op(self, symbol) # Install the op in the cache. install_func(op) # Add to the generation queue if the module has a generator callback. - if hasattr(spec_mod, 'generator_cb') and spec_mod.generator_cb is not None: - assert callable(spec_mod.generator_cb) - self._generate_queue.append(spec_mod) - file_name = spec_mod.modcls.__name__ + ".sv" + if len(builder.generators) > 0: + self._generate_queue.append(builder) + file_name = builder.modcls.__name__ + ".sv" outfn = self.output_directory / file_name self.files.add(outfn) self.mod_files.add(outfn) op.fileName = ir.StringAttr.get(str(file_name)) + return op @staticmethod def current(): @@ -247,7 +238,7 @@ def get_instance(self, mod_cls: object, instance_name: str = None) -> InstanceHierarchyRoot: assert len(self._generate_queue) == 0, "Ungenerated modules left" - mod = mod_cls._pycde_mod + mod = mod_cls._builder key = (mod, instance_name) if key not in self._instance_roots: self._instance_roots[key] = InstanceHierarchyRoot(mod, instance_name, @@ -448,16 +439,19 @@ def get_symbol_pyproxy(self, symbol): def get_pyproxy_symbol(self, spec_mod) -> str: """Get the symbol for a module or its associated _PyProxy.""" - if not isinstance(spec_mod, _SpecializedModule): - if hasattr(spec_mod, "_pycde_mod"): - spec_mod = spec_mod._pycde_mod + if not isinstance(spec_mod, Module): + if isinstance(spec_mod, ModuleLikeType): + spec_mod = spec_mod._builder if spec_mod not in self._pyproxy_symbols: return None return self._pyproxy_symbols[spec_mod] - def get_circt_mod(self, spec_mod: _SpecializedModule) -> ir.Operation: + def get_circt_mod(self, spec_mod: Module) -> Optional[ir.Operation]: """Get the CIRCT module op for a PyCDE module.""" - return self.symbols[self.get_pyproxy_symbol(spec_mod)] + sym = self.get_pyproxy_symbol(spec_mod) + if sym in self.symbols: + return self.symbols[sym] + return None def _build_instance_hier_cache(self): """If the instance hierarchy cache doesn't exist, build it.""" @@ -535,8 +529,8 @@ def get_or_create_inst_from_op(self, op: ir.OpView) -> pi.Instance: raise TypeError( "Can only resolve from InstanceHierarchyOp or DynamicInstanceOp") - def get_sym_ops_in_module( - self, module: _SpecializedModule) -> Dict[ir.Attribute, ir.Operation]: + def get_sym_ops_in_module(self, + module: Module) -> Dict[ir.Attribute, ir.Operation]: """Look into the IR inside 'module' for any ops which have a `sym_name` attribute. Cached.""" diff --git a/frontends/PyCDE/src/testing.py b/frontends/PyCDE/src/testing.py index 1a105cb5a60d..e099613c77f7 100644 --- a/frontends/PyCDE/src/testing.py +++ b/frontends/PyCDE/src/testing.py @@ -1,10 +1,9 @@ -from pycde import System, module +from pycde import System, Module import builtins import inspect from pathlib import Path import subprocess -import inspect import re import os @@ -23,11 +22,14 @@ def unittestmodule(generate=True, """ def testmodule_inner(func_or_class): - mod = module(func_or_class) - # Apply any provided kwargs if this was a function. if inspect.isfunction(func_or_class): - mod = mod(**kwargs) + mod = func_or_class(**kwargs) + elif inspect.isclass(func_or_class) and issubclass(func_or_class, Module): + mod = func_or_class + else: + raise AssertionError("unittest() must decorate a function or" + " a Module subclass") # Add the module to global scope in case it's referenced within the # module generator functions diff --git a/frontends/PyCDE/test/behavioral.py b/frontends/PyCDE/test/behavioral.py index 5352ee9f35aa..ab01ae551b73 100644 --- a/frontends/PyCDE/test/behavioral.py +++ b/frontends/PyCDE/test/behavioral.py @@ -1,6 +1,6 @@ # RUN: %PYTHON% py-split-input-file.py %s | FileCheck %s -from pycde import generator, types, Input, Output +from pycde import generator, types, Module, Input, Output from pycde.behavioral import If, Else, EndIf from pycde.testing import unittestmodule @@ -19,7 +19,7 @@ @unittestmodule() -class IfNestedTest: +class IfNestedTest(Module): a = Input(types.ui8) b = Input(types.ui8) cond = Input(types.i1) @@ -60,7 +60,7 @@ def build(ports): @unittestmodule() -class IfDefaultTest: +class IfDefaultTest(Module): a = Input(types.ui8) b = Input(types.ui8) cond = Input(types.i1) @@ -83,7 +83,7 @@ def build(ports): @unittestmodule() -class IfMismatchErrorTest: +class IfMismatchErrorTest(Module): cond = Input(types.i1) a = Input(types.ui8) b = Input(types.ui4) @@ -104,7 +104,7 @@ def build(ports): @unittestmodule() -class IfMismatchEndIfTest: +class IfMismatchEndIfTest(Module): @generator def build(ports): @@ -116,7 +116,7 @@ def build(ports): @unittestmodule() -class IfCondErrorTest: +class IfCondErrorTest(Module): cond = Input(types.i2) @generator diff --git a/frontends/PyCDE/test/cocotb_testbench.py b/frontends/PyCDE/test/cocotb_testbench.py index c1ddc0375ad2..79a97a9cea9d 100644 --- a/frontends/PyCDE/test/cocotb_testbench.py +++ b/frontends/PyCDE/test/cocotb_testbench.py @@ -18,7 +18,7 @@ @module def make_adder(width): - class Adder: + class Adder(Module): in1 = Input(types.int(width)) in2 = Input(types.int(width)) out = Output(types.int(width)) @@ -31,7 +31,7 @@ def build(ports): @module -class RegAdd: +class RegAdd(Module): rst = Input(types.i1) clk = Clock() in1 = Input(types.i16) @@ -104,14 +104,14 @@ async def inc_test(ports): @externmodule("adder") -class ExternAdder: +class ExternAdder(Module): in1 = Input(types.i16) in2 = Input(types.i16) out = Output(types.i16) @module -class RegAdd: +class RegAdd(Module): rst = Input(types.i1) clk = Clock() in1 = Input(types.i16) diff --git a/frontends/PyCDE/test/compreg.py b/frontends/PyCDE/test/compreg.py index a9d52e3facb4..0db51bd07558 100644 --- a/frontends/PyCDE/test/compreg.py +++ b/frontends/PyCDE/test/compreg.py @@ -4,7 +4,7 @@ # RUN: FileCheck %s --input-file %t/hw/CompReg.tcl --check-prefix TCL import pycde -from pycde import types, module, AppID, Clock, Input, Output +from pycde import types, Module, AppID, Clock, Input, Output from pycde.devicedb import LocationVector from pycde.module import generator @@ -25,8 +25,7 @@ # CHECK: end -@module -class CompReg: +class CompReg(Module): clk = Clock() rst = Input(types.i1) input = Input(types.i8) diff --git a/frontends/PyCDE/test/errors.py b/frontends/PyCDE/test/errors.py index 30a3251587bd..7a7b13d74804 100644 --- a/frontends/PyCDE/test/errors.py +++ b/frontends/PyCDE/test/errors.py @@ -1,12 +1,12 @@ # RUN: %PYTHON% py-split-input-file.py %s | FileCheck %s from pycde import Clock, Input, types, System -from pycde.module import AppID, externmodule, generator, module +from pycde.module import AppID, generator, Module, modparams from pycde.testing import unittestmodule # CHECK: TypeError: Module parameter definitions cannot have *args -@externmodule +@modparams def foo(*args): pass @@ -15,7 +15,7 @@ def foo(*args): # CHECK: TypeError: Module parameter definitions cannot have **kwargs -@externmodule +@modparams def bar(**kwargs): pass @@ -24,7 +24,7 @@ def bar(**kwargs): @unittestmodule() -class ClkError: +class ClkError(Module): a = Input(types.i32) @generator @@ -37,7 +37,7 @@ def build(ports): @unittestmodule() -class AppIDError: +class AppIDError(Module): @generator def build(ports): @@ -49,8 +49,7 @@ def build(ports): # ----- -@module -class Test: +class Test(Module): clk = Clock() x = Input(types.i32) @@ -70,7 +69,7 @@ def build(ports): @unittestmodule() -class OperatorError: +class OperatorError(Module): a = Input(types.i32) b = Input(types.si32) @@ -84,7 +83,7 @@ def build(ports): @unittestmodule() -class OperatorError2: +class OperatorError2(Module): a = Input(types.i32) b = Input(types.si32) diff --git a/frontends/PyCDE/test/esi.py b/frontends/PyCDE/test/esi.py index 0aaa1dcee837..ebd0074a4a6b 100644 --- a/frontends/PyCDE/test/esi.py +++ b/frontends/PyCDE/test/esi.py @@ -1,7 +1,7 @@ # RUN: rm -rf %t # RUN: %PYTHON% %s %t 2>&1 | FileCheck %s -from pycde import (Clock, Input, InputChannel, OutputChannel, module, generator, +from pycde import (Clock, Input, InputChannel, OutputChannel, Module, generator, types) from pycde import esi from pycde.common import Output @@ -19,8 +19,7 @@ class HostComms: to_client_type=types.i32) -@module -class Producer: +class Producer(Module): clk = Input(types.i1) int_out = OutputChannel(types.i32) @@ -30,8 +29,7 @@ def construct(ports): ports.int_out = chan -@module -class Consumer: +class Consumer(Module): clk = Input(types.i1) int_in = InputChannel(types.i32) @@ -57,7 +55,7 @@ def construct(ports): @unittestmodule(print=True) -class LoopbackTop: +class LoopbackTop(Module): clk = Clock(types.i1) rst = Input(types.i1) @@ -77,14 +75,14 @@ def construct(ports): # CHECK: %chanOutput, %ready = esi.wrap.vr %1, %valid : i16 # CHECK: msft.output @unittestmodule(print=True) -class LoopbackInOutTop: +class LoopbackInOutTop(Module): clk = Clock(types.i1) rst = Input(types.i1) @generator - def construct(ports): + def construct(self): # Use Cosim to implement the standard 'HostComms' service. - esi.Cosim(HostComms, ports.clk, ports.rst) + esi.Cosim(HostComms, self.clk, self.rst) loopback = Wire(types.channel(types.i16)) from_host = HostComms.req_resp(loopback, "loopback_inout") @@ -96,8 +94,7 @@ def construct(ports): loopback.assign(data_chan) -@esi.ServiceImplementation(HostComms) -class MultiplexerService: +class MultiplexerService(esi.ServiceImplementation): clk = Clock() rst = Input(types.i1) @@ -110,21 +107,21 @@ class MultiplexerService: trunk_out_ready = Input(types.i1) @generator - def generate(ports, channels): + def generate(self, channels): input_reqs = channels.to_server_reqs if len(input_reqs) > 1: raise Exception("Multiple to_server requests not supported") - MultiplexerService.unwrap_and_pad(ports, input_reqs[0]) + MultiplexerService.unwrap_and_pad(self, input_reqs[0]) output_reqs = channels.to_client_reqs if len(output_reqs) > 1: raise Exception("Multiple to_client requests not supported") output_req = output_reqs[0] output_chan, ready = MultiplexerService.slice_and_wrap( - ports, output_req.type) + self, output_req.type) output_req.assign(output_chan) - ports.trunk_in_ready = ready + self.trunk_in_ready = ready @staticmethod def slice_and_wrap(ports, channel_type: ChannelType): @@ -158,7 +155,7 @@ def unwrap_and_pad(ports, input_channel: ChannelValue): @unittestmodule(run_passes=True, print_after_passes=True, emit_outputs=True) -class MultiplexerTop: +class MultiplexerTop(Module): clk = Clock(types.i1) rst = Input(types.i1) @@ -171,7 +168,8 @@ class MultiplexerTop: @generator def construct(ports): - m = MultiplexerService(clk=ports.clk, + m = MultiplexerService(HostComms, + clk=ports.clk, rst=ports.rst, trunk_in=ports.trunk_in, trunk_in_valid=ports.trunk_in_valid, diff --git a/frontends/PyCDE/test/good_example.py b/frontends/PyCDE/test/good_example.py index 0909e427f43a..c3e95b8f5f5a 100644 --- a/frontends/PyCDE/test/good_example.py +++ b/frontends/PyCDE/test/good_example.py @@ -4,14 +4,13 @@ # This is intended to be a simple 'tutorial' example. Run it as a test to # ensure that we keep it up to date (ensure it doesn't crash). -from pycde import dim, module, generator, types, Clock, Input, Output +from pycde import dim, generator, types, Clock, Input, Output, Module import pycde import sys -@module -class Mux: +class Mux(Module): clk = Clock() data = Input(dim(8, 14)) sel = Input(types.i4) diff --git a/frontends/PyCDE/test/instances.py b/frontends/PyCDE/test/instances.py index 2f0efff06c40..0d0d4bca69fc 100644 --- a/frontends/PyCDE/test/instances.py +++ b/frontends/PyCDE/test/instances.py @@ -10,16 +10,14 @@ import sys from pycde.instance import InstanceDoesNotExistError, Instance, RegInstance -from pycde.module import AppID +from pycde.module import AppID, Module -@pycde.externmodule -class Nothing: +class Nothing(Module): pass -@pycde.module -class Delay: +class Delay(Module): clk = pycde.Clock() x = pycde.Input(pycde.types.i1) y = pycde.Output(pycde.types.i1) @@ -35,8 +33,7 @@ def construct(mod): print(f"r appid: {r.appid}") -@pycde.module -class UnParameterized: +class UnParameterized(Module): clk = pycde.Clock() x = pycde.Input(pycde.types.i1) y = pycde.Output(pycde.types.i1) @@ -47,8 +44,7 @@ def construct(mod): mod.y = Delay(clk=mod.clk, x=mod.x).y -@pycde.module -class Test: +class Test(Module): clk = pycde.Clock() @pycde.generator @@ -63,9 +59,9 @@ def build(ports): t = pycde.System([Test], name="Test", output_directory=sys.argv[1]) t.generate(["construct"]) t.print() -# CHECK: msft.module @UnParameterized -# CHECK-NOT: msft.module @UnParameterized +# CHECK: Test.print() +# CHECK: UnParameterized.print() print(PhysLocation(PrimitiveType.DSP, 39, 25)) diff --git a/frontends/PyCDE/test/module_naming.py b/frontends/PyCDE/test/module_naming.py index e54e1c6e2306..7b60674c5594 100644 --- a/frontends/PyCDE/test/module_naming.py +++ b/frontends/PyCDE/test/module_naming.py @@ -4,10 +4,10 @@ import pycde.dialects.hw -@pycde.module +@pycde.modparams def Parameterized(param): - class TestModule: + class TestModule(pycde.Module): x = pycde.Input(pycde.types.i1) y = pycde.Output(pycde.types.i1) @@ -18,8 +18,7 @@ def construct(ports): return TestModule -@pycde.module -class UnParameterized: +class UnParameterized(pycde.Module): x = pycde.Input(pycde.types.i1) y = pycde.Output(pycde.types.i1) @@ -28,8 +27,7 @@ def construct(ports): ports.y = ports.x -@pycde.module -class Test: +class Test(pycde.Module): inputs = [] outputs = [] diff --git a/frontends/PyCDE/test/muxing.py b/frontends/PyCDE/test/muxing.py index b03eb4bbbc56..157eaeaa698a 100644 --- a/frontends/PyCDE/test/muxing.py +++ b/frontends/PyCDE/test/muxing.py @@ -1,6 +1,6 @@ # RUN: %PYTHON% py-split-input-file.py %s | FileCheck %s -from pycde import generator, dim, Clock, Input, Output, Value, types +from pycde import generator, dim, Clock, Input, Output, Module, Value, types from pycde.constructs import Mux from pycde.testing import unittestmodule @@ -35,7 +35,7 @@ def array_from_tuple(*input): @unittestmodule() -class ComplexMux: +class ComplexMux(Module): Clk = Clock() In = Input(dim(3, 4, 5)) @@ -71,7 +71,7 @@ def create(ports): @unittestmodule() -class Slicing: +class Slicing(Module): In = Input(dim(8, 4, 5)) Sel8 = Input(types.i8) Sel2 = Input(types.i2) diff --git a/frontends/PyCDE/test/polynomial.py b/frontends/PyCDE/test/polynomial.py index 3e32c7693aac..a0f954b61eec 100755 --- a/frontends/PyCDE/test/polynomial.py +++ b/frontends/PyCDE/test/polynomial.py @@ -5,30 +5,27 @@ from __future__ import annotations import pycde -from pycde import (AppID, Input, Output, module, externmodule, generator, types) +from pycde import (AppID, Input, Output, generator, types) +from pycde.module import Module, modparams from pycde.dialects import comb, hw -from pycde.circt.support import connect import sys -@module +@modparams def PolynomialCompute(coefficients: Coefficients): - class PolynomialCompute: + class PolynomialCompute(Module): """Module to compute ax^3 + bx^2 + cx + d for design-time coefficients""" + module_name = f"PolyComputeForCoeff_{coefficients.coeff}" # Evaluate polynomial for 'x'. x = Input(types.i32) y = Output(types.int(8 * 4)) - def __init__(self, name: str): + def __init__(self, name: str, **kwargs): """coefficients is in 'd' -> 'a' order.""" - self.instance_name = name - - @staticmethod - def get_module_name(): - return f"PolyComputeForCoeff_{coefficients.coeff}" + super().__init__(instance_name=name, **kwargs) @generator def construct(mod): @@ -57,22 +54,23 @@ def construct(mod): return PolynomialCompute -@externmodule("supercooldevice") -class CoolPolynomialCompute: +class CoolPolynomialCompute(Module): + module_name = "supercooldevice" x = Input(types.i32) y = Output(types.i32) - def __init__(self, coefficients): + def __init__(self, coefficients, **inputs): + super().__init__(**inputs) self.coefficients = coefficients -@externmodule("parameterized_extern") +@modparams def ExternWithParams(a, b): typedef1 = types.struct({"a": types.i1}, "exTypedef") - class M: - pass + class M(Module): + module_name = "parameterized_extern" return M @@ -83,28 +81,26 @@ def __init__(self, coeff): self.coeff = coeff -@module -class PolynomialSystem: +class PolynomialSystem(Module): y = Output(types.i32) @generator - def construct(ports): + def construct(self): i32 = types.i32 x = hw.ConstantOp(i32, 23) poly = PolynomialCompute(Coefficients([62, 42, 6]))("example", - appid=AppID("poly", 0)) - connect(poly.x, x) + appid=AppID("poly", 0), + x=x) PolynomialCompute(coefficients=Coefficients([62, 42, 6]))("example2", x=poly.y) PolynomialCompute(Coefficients([1, 2, 3, 4, 5]))("example2", x=poly.y) - cp = CoolPolynomialCompute([4, 42]) - cp.x.connect(23) + CoolPolynomialCompute([4, 42], x=23) m = ExternWithParams(8, 3)() m.name = "pexternInst" - ports.y = poly.y + self.y = poly.y poly = pycde.System([PolynomialSystem], @@ -132,6 +128,7 @@ def construct(ports): print("Generating rest...") poly.generate() +poly.print() print("=== Post-generate IR...") poly.run_passes() diff --git a/frontends/PyCDE/test/pycde_values.py b/frontends/PyCDE/test/pycde_values.py index fad7cfbdf2db..40e0d73512f7 100644 --- a/frontends/PyCDE/test/pycde_values.py +++ b/frontends/PyCDE/test/pycde_values.py @@ -1,7 +1,7 @@ # RUN: %PYTHON% %s | FileCheck %s from pycde.dialects import comb, hw -from pycde import dim, generator, types, Input, Output +from pycde import dim, generator, types, Input, Output, Module from pycde.value import And, Or from pycde.testing import unittestmodule @@ -9,7 +9,7 @@ @unittestmodule(SIZE=4) def MyModule(SIZE: int): - class Mod: + class Mod(Module): inp = Input(dim(SIZE)) out = Output(dim(SIZE)) @@ -31,7 +31,7 @@ def construct(mod): # CHECK-LABEL: msft.module @Mod {} (%inp: !hw.array<5xi1>) @unittestmodule() -class Mod: +class Mod(Module): inp = Input(dim(types.i1, 5)) @generator diff --git a/frontends/PyCDE/test/syntactic_sugar.py b/frontends/PyCDE/test/syntactic_sugar.py index 885070f22cdd..86264dae3c65 100644 --- a/frontends/PyCDE/test/syntactic_sugar.py +++ b/frontends/PyCDE/test/syntactic_sugar.py @@ -1,8 +1,6 @@ # RUN: %PYTHON% %s | FileCheck %s -from pycde import (Output, Input, module, generator, types, dim, System, - no_connect) -from pycde.module import externmodule +from pycde import (Output, Input, generator, types, dim, Module) from pycde.testing import unittestmodule # CHECK-LABEL: msft.module @Top {} () attributes {fileName = "Top.sv"} { @@ -15,9 +13,6 @@ # CHECK: %c7_i12_0 = hw.constant 7 : i12 # CHECK: hw.struct_create (%c7_i12_0) : !hw.typealias<@pycde::@bar, !hw.struct> # CHECK: %Taps.taps = msft.instance @Taps @Taps() : () -> !hw.array<3xi8> -# CHECK: %c0_i4 = hw.constant 0 : i4 -# CHECK: [[ARG0:%.+]] = hw.bitcast %c0_i4 : (i4) -> !hw.array<4xi1> -# CHECK: msft.instance @StupidLegacy @StupidLegacy([[ARG0]]) : (!hw.array<4xi1>) -> () # CHECK: msft.output # CHECK-LABEL: msft.module @Taps {} () -> (taps: !hw.array<3xi8>) attributes {fileName = "Taps.sv"} { # CHECK: %c-53_i8 = hw.constant -53 : i8 @@ -25,11 +20,9 @@ # CHECK: %c23_i8 = hw.constant 23 : i8 # CHECK: [[R0:%.+]] = hw.array_create %c23_i8, %c100_i8, %c-53_i8 : i8 # CHECK: msft.output [[R0]] : !hw.array<3xi8> -# CHECK: msft.module.extern @StupidLegacy(%ignore: !hw.array<4xi1>) attributes {verilogName = "StupidLegacy"} -@module -class Taps: +class Taps(Module): taps = Output(dim(8, 3)) @generator @@ -37,16 +30,11 @@ def build(ports): ports.taps = [203, 100, 23] -@externmodule -class StupidLegacy: - ignore = Input(dim(1, 4)) - - BarType = types.struct({"foo": types.i12}, "bar") @unittestmodule() -class Top: +class Top(Module): @generator def build(_): @@ -57,7 +45,6 @@ def build(_): BarType({"foo": 7}) Taps() - StupidLegacy(ignore=no_connect) # ----- @@ -74,7 +61,7 @@ def build(_): @unittestmodule() -class ComplexPorts: +class ComplexPorts(Module): clk = Input(types.i1) data_in = Input(dim(32, 3)) sel = Input(types.i2) @@ -85,10 +72,8 @@ class ComplexPorts: c = Output(types.i32) @generator - def build(ports): - assert len(ports.data_in) == 3 - ports.set_all_ports({ - 'a': ports.data_in[0].reg(ports.clk).reg(ports.clk), - 'b': ports.data_in[ports.sel], - 'c': ports.struct_data_in.foo[:-4] - }) + def build(self): + assert len(self.data_in) == 3 + self.a = self.data_in[0].reg(self.clk).reg(self.clk) + self.b = self.data_in[self.sel] + self.c = self.struct_data_in.foo[:-4] diff --git a/frontends/PyCDE/test/test_constructs.py b/frontends/PyCDE/test/test_constructs.py index 62c175adf4a6..2e315749732e 100644 --- a/frontends/PyCDE/test/test_constructs.py +++ b/frontends/PyCDE/test/test_constructs.py @@ -1,6 +1,6 @@ # RUN: %PYTHON% %s | FileCheck %s -from pycde import generator, types, dim +from pycde import generator, types, dim, Module from pycde.common import Clock, Input, Output from pycde.constructs import ControlReg, NamedWire, Reg, Wire, SystolicArray from pycde.dialects import comb @@ -21,7 +21,7 @@ @unittestmodule() -class WireAndRegTest: +class WireAndRegTest(Module): In = Input(types.i8) InCE = Input(types.i1) clk = Clock() @@ -63,7 +63,7 @@ def create(ports): # CHECK: %sum__reg1_0_0 = sv.reg sym @sum__reg1 : !hw.inout # CHECK: sv.read_inout %sum__reg1_0_0 : !hw.inout @unittestmodule(print=True, run_passes=True, print_after_passes=True) -class SystolicArrayTest: +class SystolicArrayTest(Module): clk = Input(types.i1) col_data = Input(dim(8, 2)) row_data = Input(dim(8, 3)) @@ -97,7 +97,7 @@ def pe(r, c): # CHECK: [[r6:%.+]] = comb.mux bin [[r2]], %true{{.*}}, [[r5]] # CHECK: msft.output %state @unittestmodule() -class ControlRegTest: +class ControlRegTest(Module): clk = Clock() rst = Input(types.i1) a1 = Input(types.i1) diff --git a/frontends/PyCDE/test/test_constructs_errors.py b/frontends/PyCDE/test/test_constructs_errors.py index 11de4477608a..35992fab86c5 100644 --- a/frontends/PyCDE/test/test_constructs_errors.py +++ b/frontends/PyCDE/test/test_constructs_errors.py @@ -1,13 +1,13 @@ # RUN: %PYTHON% py-split-input-file.py %s | FileCheck %s -from pycde import generator, types +from pycde import generator, types, Module from pycde.common import Clock, Input from pycde.constructs import Reg, Wire from pycde.testing import unittestmodule @unittestmodule() -class WireTypeTest: +class WireTypeTest(Module): In = Input(types.i8) @generator @@ -21,7 +21,7 @@ def create(ports): @unittestmodule() -class WireDoubleAssignTest: +class WireDoubleAssignTest(Module): In = Input(types.i8) @generator @@ -36,7 +36,7 @@ def create(ports): @unittestmodule() -class RegTypeTest: +class RegTypeTest(Module): clk = Clock() In = Input(types.i8) @@ -51,7 +51,7 @@ def create(ports): @unittestmodule() -class RegDoubleAssignTest: +class RegDoubleAssignTest(Module): Clk = Clock() In = Input(types.i8) diff --git a/frontends/PyCDE/test/test_esi_errors.py b/frontends/PyCDE/test/test_esi_errors.py index f5b7182617d8..41a5e0d68014 100644 --- a/frontends/PyCDE/test/test_esi_errors.py +++ b/frontends/PyCDE/test/test_esi_errors.py @@ -1,7 +1,7 @@ # RUN: rm -rf %t # RUN: %PYTHON% py-split-input-file.py %s 2>&1 | FileCheck %s -from pycde import (Clock, Input, InputChannel, OutputChannel, module, generator, +from pycde import (Clock, Input, InputChannel, OutputChannel, Module, generator, types) from pycde import esi from pycde.testing import unittestmodule @@ -13,8 +13,7 @@ class HostComms: from_host = esi.FromServer(types.any) -@module -class Producer: +class Producer(Module): clk = Input(types.i1) int_out = OutputChannel(types.i32) @@ -24,8 +23,7 @@ def construct(ports): ports.int_out = chan -@module -class Consumer: +class Consumer(Module): clk = Input(types.i1) int_in = InputChannel(types.i32) @@ -35,7 +33,7 @@ def construct(ports): @unittestmodule(print=True) -class LoopbackTop: +class LoopbackTop(Module): clk = Clock(types.i1) rst = Input(types.i1) @@ -47,11 +45,13 @@ def construct(ports): esi.Cosim(HostComms, ports.clk, ports.rst) -@esi.ServiceImplementation(HostComms) -class MultiplexerService: +class MultiplexerService(esi.ServiceImplementation): clk = Clock() rst = Input(types.i1) + def __init__(self, **inputs): + super().__init__(HostComms, **inputs) + @generator def generate(ports, channels): @@ -74,7 +74,7 @@ def generate(ports, channels): @unittestmodule(run_passes=True, print_after_passes=True) -class MultiplexerTop: +class MultiplexerTop(Module): clk = Clock(types.i1) rst = Input(types.i1) @@ -90,11 +90,13 @@ def construct(ports): # ----- -@esi.ServiceImplementation(HostComms) -class BrokenService: +class BrokenService(esi.ServiceImplementation): clk = Clock() rst = Input(types.i1) + def __init__(self, **inputs): + super().__init__(HostComms, **inputs) + @generator def generate(ports, channels): return "asdf" @@ -102,7 +104,7 @@ def generate(ports, channels): @unittestmodule(run_passes=True, print_after_passes=True) -class BrokenTop: +class BrokenTop(Module): clk = Clock(types.i1) rst = Input(types.i1) diff --git a/frontends/PyCDE/test/test_fsm.py b/frontends/PyCDE/test/test_fsm.py index 85fa9afbd65b..8067761b92dc 100644 --- a/frontends/PyCDE/test/test_fsm.py +++ b/frontends/PyCDE/test/test_fsm.py @@ -1,6 +1,7 @@ # RUN: %PYTHON% py-split-input-file.py %s | FileCheck %s +# XFAIL: * -from pycde import System, Input, Output, generator +from pycde import System, Input, Output, generator, Module from pycde.dialects import comb from pycde import fsm from pycde.pycde_types import types @@ -35,7 +36,7 @@ class FSM: @unittestmodule() -class FSMUser: +class FSMUser(Module): a = Input(types.i1) b = Input(types.i1) c = Input(types.i1) @@ -184,7 +185,7 @@ def nand(*args): @unittestmodule() -class FSMUser: +class FSMUser(Module): go = Input(types.i1) clk = Input(types.i1) rst = Input(types.i1) diff --git a/frontends/PyCDE/test/test_fsm_errors.py b/frontends/PyCDE/test/test_fsm_errors.py index 408d4582991a..ba4112779938 100644 --- a/frontends/PyCDE/test/test_fsm_errors.py +++ b/frontends/PyCDE/test/test_fsm_errors.py @@ -1,10 +1,8 @@ # RUN: %PYTHON% py-split-input-file.py %s | FileCheck %s from pycde import System, Input, Output, generator -from pycde.dialects import comb from pycde import fsm from pycde.pycde_types import types -from pycde.testing import unittestmodule @fsm.machine() @@ -32,4 +30,4 @@ class FSM: class FSM: a = Input(types.i1) A = fsm.State(initial=True) - B = fsm.State(initial=True) \ No newline at end of file + B = fsm.State(initial=True) diff --git a/frontends/PyCDE/test/test_hwarith.py b/frontends/PyCDE/test/test_hwarith.py index bc724d5e20fd..d740c915919c 100644 --- a/frontends/PyCDE/test/test_hwarith.py +++ b/frontends/PyCDE/test/test_hwarith.py @@ -1,6 +1,6 @@ # RUN: %PYTHON% py-split-input-file.py %s | FileCheck %s -from pycde import Input, Output, generator +from pycde import Input, Output, generator, Module from pycde.testing import unittestmodule from pycde.pycde_types import types @@ -15,7 +15,7 @@ # CHECK-NEXT: %5 = hwarith.mul %in0, %4 {{({sv.namehint = ".*"} )?}}: (si16, si16) -> si32 # CHECK-NEXT: msft.output @unittestmodule(run_passes=True) -class InfixArith: +class InfixArith(Module): in0 = Input(types.si16) in1 = Input(types.ui16) @@ -39,7 +39,7 @@ def construct(ports): # CHECK-NEXT: comb.xor bin %in0, %c-1_i16 {{({sv.namehint = ".*"} )?}}: i16 # CHECK-NEXT: msft.output @unittestmodule(run_passes=True) -class InfixLogic: +class InfixLogic(Module): in0 = Input(types.i16) in1 = Input(types.i16) @@ -59,7 +59,7 @@ def construct(ports): # CHECK-NEXT: %1 = comb.icmp bin ne %in0, %in1 {{({sv.namehint = ".*"} )?}}: i16 # CHECK-NEXT: msft.output @unittestmodule(run_passes=True) -class SignlessInfixComparison: +class SignlessInfixComparison(Module): in0 = Input(types.i16) in1 = Input(types.i16) @@ -77,7 +77,7 @@ def construct(ports): # CHECK-NEXT: %1 = hwarith.icmp ne %in0, %in1 {sv.namehint = "in0_neq_in1"} : ui16, ui16 # CHECK-NEXT: msft.output @unittestmodule(run_passes=False) -class InfixComparison: +class InfixComparison(Module): in0 = Input(types.ui16) in1 = Input(types.ui16) @@ -97,7 +97,7 @@ def construct(ports): # CHECK-NEXT: %3 = hwarith.cast %2 {{({sv.namehint = ".*"} )?}}: (si19) -> i16 # CHECK-NEXT: msft.output %3 {{({sv.namehint = ".*"} )?}}: i16 @unittestmodule(run_passes=True) -class Multiple: +class Multiple(Module): in0 = Input(types.si16) in1 = Input(types.si16) out0 = Output(types.i16) @@ -120,7 +120,7 @@ def construct(ports): # CHECK-NEXT: %6 = hwarith.cast %0 {{({sv.namehint = ".*"} )?}}: (si16) -> si24 # CHECK-NEXT: msft.output @unittestmodule(run_passes=True) -class Casting: +class Casting(Module): in0 = Input(types.i16) @generator @@ -141,7 +141,7 @@ def construct(ports): # CHECK-NEXT: %0 = comb.add %in0, %in1 {{({sv.namehint = ".*"} )?}}: i16 # CHECK-NEXT: hw.output %0 {{({sv.namehint = ".*"} )?}}: i16 @unittestmodule(generate=True, run_passes=True, print_after_passes=True) -class Lowering: +class Lowering(Module): in0 = Input(types.i16) in1 = Input(types.i16) out0 = Output(types.i16) diff --git a/frontends/PyCDE/test/test_import_hw_modules.py b/frontends/PyCDE/test/test_import_hw_modules.py index d34d58b8115c..c6912b3363af 100644 --- a/frontends/PyCDE/test/test_import_hw_modules.py +++ b/frontends/PyCDE/test/test_import_hw_modules.py @@ -1,12 +1,12 @@ # RUN: %PYTHON% py-split-input-file.py %s | FileCheck %s -from pycde.circt.ir import Module +from pycde.circt.ir import Module as IrModule from pycde.circt.dialects import hw -from pycde import Input, Output, System, generator, module, types +from pycde import Input, Output, System, generator, Module, types from pycde.module import import_hw_module -mlir_module = Module.parse(""" +mlir_module = IrModule.parse(""" hw.module @add(%a: i1, %b: i1) -> (out: i1) { %0 = comb.add %a, %b : i1 hw.output %0 : i1 @@ -25,8 +25,7 @@ imported_modules.append(imported_module) -@module -class Top: +class Top(Module): a = Input(types.i1) b = Input(types.i1) out0 = Output(types.i1) @@ -43,7 +42,6 @@ def generate(ports): system = System([Top]) -system.import_modules(imported_modules) system.generate() # CHECK: msft.module @Top {} (%a: i1, %b: i1) -> (out0: i1, out1: i1) diff --git a/frontends/PyCDE/test/test_misc_errors.py b/frontends/PyCDE/test/test_misc_errors.py index 9d87d65c7b09..b169a6d9ef04 100644 --- a/frontends/PyCDE/test/test_misc_errors.py +++ b/frontends/PyCDE/test/test_misc_errors.py @@ -1,12 +1,12 @@ # RUN: %PYTHON% py-split-input-file.py %s | FileCheck %s -from pycde import Input, generator, dim +from pycde import Input, generator, dim, Module from pycde.constructs import Mux from pycde.testing import unittestmodule @unittestmodule() -class Mux1: +class Mux1(Module): In = Input(dim(3, 4, 5)) Sel = Input(dim(8)) @@ -21,7 +21,7 @@ def create(ports): @unittestmodule() -class Mux2: +class Mux2(Module): Sel = Input(dim(8)) diff --git a/frontends/PyCDE/test/test_ndarray.py b/frontends/PyCDE/test/test_ndarray.py index c4089b947412..e80349b44ac0 100644 --- a/frontends/PyCDE/test/test_ndarray.py +++ b/frontends/PyCDE/test/test_ndarray.py @@ -1,12 +1,11 @@ # RUN: %PYTHON% py-split-input-file.py %s | FileCheck %s -from pycde import Input, Output, generator +from pycde import Input, Output, generator, Module from pycde.testing import unittestmodule from pycde.ndarray import NDArray from pycde.dialects import hw from pycde.pycde_types import types, dim -from pycde.value import ListValue # ndarray transposition via injected ndarray on a ListValue # Putting this as first test in case users use this file as a reference. @@ -15,7 +14,7 @@ @unittestmodule() -class M1: +class M1(Module): in1 = Input(dim(types.i32, 4, 8)) out = Output(dim(types.i32, 8, 4)) @@ -30,7 +29,7 @@ def build(ports): @unittestmodule() -class M1: +class M1(Module): in1 = Input(dim(types.i32, 4, 8)) out = Output(dim(types.i32, 2, 16)) @@ -43,7 +42,7 @@ def build(ports): @unittestmodule() -class M2: +class M2(Module): in0 = Input(dim(types.i32, 16)) in1 = Input(types.i32) t_c = dim(types.i32, 8, 4) @@ -62,7 +61,7 @@ def build(ports): # ----- @unittestmodule() -class M5: +class M5(Module): in0 = Input(dim(types.i32, 16)) in1 = Input(types.i32) t_c = dim(types.i32, 16) @@ -100,7 +99,7 @@ def build(ports): @unittestmodule() -class M1: +class M1(Module): in1 = Input(dim(types.i32, 10, 10)) out = Output(dim(types.i32, 10, 10)) @@ -116,7 +115,7 @@ def build(ports): @unittestmodule() -class M1: +class M1(Module): in1 = Input(dim(types.i32, 10)) out = Output(dim(types.i32, 10)) @@ -132,7 +131,7 @@ def build(ports): @unittestmodule() -class M1: +class M1(Module): in1 = Input(dim(types.i32, 10)) in2 = Input(dim(types.i32, 10)) in3 = Input(dim(types.i32, 10)) @@ -153,7 +152,7 @@ def build(ports): @unittestmodule() -class M1: +class M1(Module): in1 = Input(dim(types.i32, 10)) out = Output(dim(types.i32, 10)) @@ -186,7 +185,7 @@ def build(ports): @unittestmodule() -class M1: +class M1(Module): out = Output(dim(types.i32, 3, 3)) @generator diff --git a/frontends/PyCDE/test/test_ndarray_errors.py b/frontends/PyCDE/test/test_ndarray_errors.py index a9b0724f17cc..9a5fc312132d 100644 --- a/frontends/PyCDE/test/test_ndarray_errors.py +++ b/frontends/PyCDE/test/test_ndarray_errors.py @@ -1,6 +1,6 @@ # RUN: %PYTHON% py-split-input-file.py %s | FileCheck %s -from pycde import System, Input, generator +from pycde import Module, Input, generator from pycde.testing import unittestmodule from pycde.pycde_types import types, dim from pycde.ndarray import NDArray @@ -9,7 +9,7 @@ @unittestmodule() -class M1: +class M1(Module): in1 = Input(types.i32) @generator @@ -28,7 +28,7 @@ def build(ports): @unittestmodule() -class M1: +class M1(Module): in1 = Input(types.i33) @generator @@ -44,7 +44,7 @@ def build(ports): @unittestmodule() -class M1: +class M1(Module): in1 = Input(dim(types.i32, 10)) @generator @@ -59,7 +59,7 @@ def build(ports): @unittestmodule() -class M1: +class M1(Module): in1 = Input(types.i31) @generator diff --git a/frontends/PyCDE/test/test_slicing.py b/frontends/PyCDE/test/test_slicing.py index 1b57650ba6c7..e497d4e9d481 100644 --- a/frontends/PyCDE/test/test_slicing.py +++ b/frontends/PyCDE/test/test_slicing.py @@ -1,6 +1,6 @@ # RUN: %PYTHON% py-split-input-file.py %s | FileCheck %s -from pycde import System, Input, Output, module, generator +from pycde import System, Input, Output, generator, Module from pycde.pycde_types import dim # CHECK-LABEL: msft.module @MyMod {} (%in_port: i8) -> (out0: i5, out1: i5) attributes {fileName = "MyMod.sv"} { @@ -10,8 +10,7 @@ # CHECK: } -@module -class MyMod: +class MyMod(Module): in_port = Input(dim(8)) out0 = Output(dim(5)) out1 = Output(dim(5)) diff --git a/frontends/PyCDE/test/verilog_readablility.py b/frontends/PyCDE/test/verilog_readablility.py index a7f6e9ca03f3..fc077bc8a5d1 100644 --- a/frontends/PyCDE/test/verilog_readablility.py +++ b/frontends/PyCDE/test/verilog_readablility.py @@ -1,10 +1,9 @@ # RUN: %PYTHON% %s | FileCheck %s -from pycde import (Output, Input, module, generator, types, dim, System) +from pycde import (Output, Input, Module, generator, types, dim, System) -@module -class WireNames: +class WireNames(Module): clk = Input(types.i1) data_in = Input(dim(32, 3)) sel = Input(types.i2) @@ -13,14 +12,12 @@ class WireNames: b = Output(types.i32) @generator - def build(ports): - foo = ports.data_in[0] + def build(self): + foo = self.data_in[0] foo.name = "foo" arr_data = dim(32, 4)([1, 2, 3, 4], "arr_data") - ports.set_all_ports({ - 'a': foo.reg(ports.clk).reg(ports.clk), - 'b': arr_data[ports.sel], - }) + self.a = foo.reg(self.clk).reg(self.clk) + self.b = arr_data[self.sel] sys = System([WireNames])