Skip to content

Commit

Permalink
[PyCDE] Change module API to extend Module instead of using a decor…
Browse files Browse the repository at this point in the history
…ator (#4556)

Refactor PyCDE module creation to a subclass based scheme. This is a breaking API change, but a largely mechanical update to users code. Reasons for the refactor:
- It wasn't possible for users to extend their modules by just subclassing them.
- Code completion for ports inside of generators didn't work.
- We are going to introduce functional models, wherein a method of a model will be interacting with the instance. We wanted hardware generation to have a similar feeling.
- Internally, supporting more module like constructs (e.g. ESI service implementations) involved specifying a bunch of callbacks. This replaces that with an OOP approach.
- The old code was largely ad-hoc and evolved to meet our needs, and badly in need of some TLC. The new code is far better structured (and documented).
  • Loading branch information
teqdruid authored Jan 20, 2023
1 parent b6d8508 commit 1e680d1
Show file tree
Hide file tree
Showing 36 changed files with 744 additions and 857 deletions.
8 changes: 3 additions & 5 deletions frontends/PyCDE/integration_test/esi_ram.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# PY: run_cosim(tmpdir, rpcschemapath, simhostport)

import pycde
from pycde import (Clock, Input, module, generator, types)
from pycde import (Clock, Input, Module, generator, types)
from pycde.constructs import Wire
from pycde import esi

Expand All @@ -24,8 +24,7 @@ class MemComms:
to_client_type=WriteType)


@module
class Mid:
class Mid(Module):
clk = Clock(types.i1)
rst = Input(types.i1)

Expand All @@ -40,8 +39,7 @@ def construct(ports):
RamI64x8.write(write_data)


@module
class top:
class top(Module):
clk = Clock(types.i1)
rst = Input(types.i1)

Expand Down
18 changes: 6 additions & 12 deletions frontends/PyCDE/integration_test/esi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
# PY: run_cosim(tmpdir, rpcschemapath, simhostport)

import pycde
from pycde import (Clock, Input, InputChannel, OutputChannel, module, generator,
from pycde import (Clock, Input, InputChannel, OutputChannel, Module, generator,
types)
from pycde import esi
from pycde.constructs import Wire
from pycde.dialects import comb

import sys

Expand All @@ -23,8 +22,7 @@ class HostComms:
to_client_type=types.i32)


@module
class Producer:
class Producer(Module):
clk = Input(types.i1)
int_out = OutputChannel(types.i32)

Expand All @@ -34,8 +32,7 @@ def construct(ports):
ports.int_out = chan


@module
class Consumer:
class Consumer(Module):
clk = Input(types.i1)
int_in = InputChannel(types.i32)

Expand All @@ -44,8 +41,7 @@ def construct(ports):
HostComms.to_host(ports.int_in, "loopback_out")


@module
class LoopbackInOutAdd7:
class LoopbackInOutAdd7(Module):

@generator
def construct(ports):
Expand All @@ -59,8 +55,7 @@ def construct(ports):
loopback.assign(data_chan)


@module
class Mid:
class Mid(Module):
clk = Clock(types.i1)
rst = Input(types.i1)

Expand All @@ -72,8 +67,7 @@ def construct(ports):
LoopbackInOutAdd7()


@module
class Top:
class Top(Module):
clk = Clock(types.i1)
rst = Input(types.i1)

Expand Down
5 changes: 2 additions & 3 deletions frontends/PyCDE/integration_test/pytorch/dot_prod_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# PY: from dot_prod_system import run_cosim
# PY: run_cosim(tmpdir, rpcschemapath, simhostport)

from pycde import Input, module, generator, types
from pycde import Input, Module, generator, types
from pycde.common import Clock
from pycde.system import System
from pycde.esi import FromServer, ToFromServer, ServiceDecl, CosimBSP
Expand Down Expand Up @@ -41,8 +41,7 @@
# output_type="linalg-on-tensors")


@module
class Gasket:
class Gasket(Module):
"""Wrap the accelerator IP module. Instantiate the requiste memories. Wire the
memories to the host and the host to the module control signals."""

Expand Down
12 changes: 6 additions & 6 deletions frontends/PyCDE/src/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,6 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .common import (AppID, Clock, Input, InputChannel, Output, OutputChannel)
from .module import (externmodule, generator, module, no_connect)
from .system import (System)
from .pycde_types import (dim, types)
from .value import (Value)

from .circt import ir
from . import circt
import atexit
Expand All @@ -24,6 +18,12 @@ def __exit_ctxt():
DefaultContext.__exit__(None, None, None)


from .common import (AppID, Clock, Input, InputChannel, Output, OutputChannel)
from .module import (generator, modparams, Module)
from .system import (System)
from .pycde_types import (dim, types)
from .value import (Value)

# Until we get source location based on Python stack traces, default to unknown
# locations.
DefaultLocation = ir.Location.unknown()
Expand Down
4 changes: 4 additions & 0 deletions frontends/PyCDE/src/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,7 @@ class _PyProxy:

def __init__(self, name: str):
self.name = name


class PortError(Exception):
pass
6 changes: 3 additions & 3 deletions frontends/PyCDE/src/constructs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .pycde_types import PyCDEType, dim, types
from .value import BitsSignal, BitVectorSignal, ListValue, Value, Signal
from .value import get_slice_bounds
from .module import generator, module, _BlockContext
from .module import generator, modparams, Module, _BlockContext
from .circt.support import get_value, BackedgeBuilder
from .circt.dialects import msft, hw, sv
from pycde.dialects import comb
Expand Down Expand Up @@ -149,10 +149,10 @@ def ControlReg(clk: Signal, rst: Signal, asserts: List[Signal],
opposite. If both an assert and a reset are active on the same cycle, the
assert takes priority."""

@module
@modparams
def ControlReg(num_asserts: int, num_resets: int):

class ControlReg:
class ControlReg(Module):
clk = Clock()
rst = Input(types.i1)
out = Output(types.i1)
Expand Down
151 changes: 75 additions & 76 deletions frontends/PyCDE/src/esi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from pycde.system import System
from .module import (Generator, _module_base, _BlockContext,
_GeneratorPortAccess, _SpecializedModule)
from .module import Generator, _BlockContext, Module, ModuleLikeBuilderBase
from pycde.value import ChannelValue, ClockSignal, Signal, Value
from .common import AppID, Input, Output, InputChannel, OutputChannel, _PyProxy
from .common import Input, Output, InputChannel, OutputChannel, _PyProxy
from .circt.dialects import esi as raw_esi, hw, msft
from .circt.support import BackedgeBuilder
from pycde.pycde_types import ChannelType, ClockType, PyCDEType, types
Expand All @@ -15,7 +14,7 @@

from pathlib import Path
import shutil
from typing import List, Optional, Type
from typing import Dict, List, Optional, Type

__dir__ = Path(__file__).parent

Expand Down Expand Up @@ -180,11 +179,10 @@ def Cosim(decl: ServiceDecl, clk, rst):
def CosimBSP(user_module):
"""Wrap and return a cosimulation 'board support package' containing
'user_module'"""
from .module import module, generator
from .module import Module, generator
from .common import Clock, Input

@module
class top:
class top(Module):
clk = Clock()
rst = Input(types.int(1))

Expand Down Expand Up @@ -256,8 +254,7 @@ class _ServiceGeneratorChannels:
"""Provide access to the channels which the service generator is responsible
for connecting up."""

def __init__(self, mod: _SpecializedModule,
req: raw_esi.ServiceImplementReqOp):
def __init__(self, mod: Module, req: raw_esi.ServiceImplementReqOp):
self._req = req
portReqsBlock = req.portReqs.blocks[0]

Expand All @@ -269,7 +266,7 @@ def __init__(self, mod: _SpecializedModule,
]

# Find the output channel requests and store the settable proxies.
num_output_ports = len(mod.output_port_lookup)
num_output_ports = len(mod.outputs)
self._output_reqs = [
_OutputChannelSetter(req, self._req.results[num_output_ports + idx])
for idx, req in enumerate(portReqsBlock)
Expand All @@ -294,75 +291,76 @@ def check_unconnected_outputs(self):
raise ValueError(f"{name_str} has not been connected.")


def ServiceImplementation(decl: Optional[ServiceDecl]):
class ServiceImplementationModuleBuilder(ModuleLikeBuilderBase):
"""Define how to build ESI service implementations. Unlike Modules, there is
no distinction between definition and instance -- ESI service providers are
built where they are instantiated."""

def instantiate(self, impl, instance_name: str, **inputs):
# Each instantiation of the ServiceImplementation has its own
# registration.
opts = _service_generator_registry.register(impl)

# Create the op.
decl_sym = None
if impl.decl is not None:
decl_sym = ir.FlatSymbolRefAttr.get(impl.decl._materialize_service_decl())
return raw_esi.ServiceInstanceOp(
result=[t for _, t in self.outputs],
service_symbol=decl_sym,
impl_type=_ServiceGeneratorRegistry._impl_type_name,
inputs=[inputs[pn].value for pn, _ in self.inputs],
impl_opts=opts,
loc=self.loc)

def generate_svc_impl(self, serviceReq: raw_esi.ServiceInstanceOp):
""""Generate the service inline and replace the `ServiceInstanceOp` which is
being implemented."""

assert len(self.generators) == 1
generator: Generator = list(self.generators.values())[0]
ports = self.generator_port_proxy(serviceReq.operation.operands, self)
with self.GeneratorCtxt(self, ports, serviceReq, generator.loc):

# Run the generator.
channels = _ServiceGeneratorChannels(self, serviceReq)
rc = generator.gen_func(ports, channels=channels)
if rc is None:
rc = True
elif not isinstance(rc, bool):
raise ValueError("Generators must a return a bool or None")
ports._check_unconnected_outputs()
channels.check_unconnected_outputs()

# Replace the output values from the service implement request op with
# the generated values. Erase the service implement request op.
for idx, port_value in enumerate(ports._output_values):
msft.replaceAllUsesWith(serviceReq.operation.results[idx],
port_value.value)
serviceReq.operation.erase()

return rc


class ServiceImplementation(Module):
"""A generator for a service implementation. Must contain a @generator method
which will be called whenever required to implement the server. Said generator
function will be called with the same 'ports' argument as modules and a
'channels' argument containing lists of the input and output channels which
need to be connected to the service being implemented."""

def wrap(service_impl, decl: Optional[ServiceDecl] = decl):

def instantiate_cb(mod: _SpecializedModule, instance_name: str,
inputs: dict, appid: AppID, loc):
# Each instantiation of the ServiceImplementation has its own
# registration.
opts = _service_generator_registry.register(mod)
decl_sym = None
if decl is not None:
decl_sym = ir.FlatSymbolRefAttr.get(decl._materialize_service_decl())
return raw_esi.ServiceInstanceOp(
result=[t for _, t in mod.output_ports],
service_symbol=decl_sym,
impl_type=_ServiceGeneratorRegistry._impl_type_name,
inputs=[inputs[pn].value for pn, _ in mod.input_ports],
impl_opts=opts,
loc=loc)

def generate(generator: Generator, spec_mod: _SpecializedModule,
serviceReq: raw_esi.ServiceInstanceOp):
arguments = serviceReq.operation.operands
with ir.InsertionPoint(
serviceReq), generator.loc, BackedgeBuilder(), _BlockContext():
# Insert generated code after the service instance op.
ports = _GeneratorPortAccess(spec_mod, arguments)

# Enter clock block implicitly if only one clock given
clk = None
if len(spec_mod.clock_ports) == 1:
clk_port = list(spec_mod.clock_ports.values())[0]
clk = ClockSignal(arguments[clk_port], ClockType())
clk.__enter__()

# Run the generator.
channels = _ServiceGeneratorChannels(spec_mod, serviceReq)
rc = generator.gen_func(ports, channels=channels)
if rc is None:
rc = True
elif not isinstance(rc, bool):
raise ValueError("Generators must a return a bool or None")
ports.check_unconnected_outputs()
channels.check_unconnected_outputs()

# Replace the output values from the service implement request op with
# the generated values. Erase the service implement request op.
for port_name, port_value in ports._output_values.items():
port_num = spec_mod.output_port_lookup[port_name]
msft.replaceAllUsesWith(serviceReq.operation.results[port_num],
port_value.value)
serviceReq.operation.erase()

if clk is not None:
clk.__exit__(None, None, None)

return rc

return _module_base(service_impl,
extern_name=None,
generator_cb=generate,
instantiate_cb=instantiate_cb)

return wrap
BuilderType = ServiceImplementationModuleBuilder

def __init__(self, decl: Optional[ServiceDecl], **inputs):
"""Instantiate a service provider for service declaration 'decl'. If decl,
implementation is expected to handle any and all service declarations."""

self.decl = decl
super().__init__(**inputs)

@property
def name(self):
return self.__class__.__name__


class _ServiceGeneratorRegistry:
Expand All @@ -372,7 +370,7 @@ class _ServiceGeneratorRegistry:
_impl_type_name = ir.StringAttr.get("pycde")

def __init__(self):
self._registry = {}
self._registry: Dict[str, ServiceImplementation] = {}

# Register myself with ESI so I can dispatch to my internal registry.
assert _ServiceGeneratorRegistry._registered is False, \
Expand All @@ -382,7 +380,8 @@ def __init__(self):
self._implement_service)
_ServiceGeneratorRegistry._registered = True

def register(self, service_implementation: _SpecializedModule) -> ir.DictAttr:
def register(self,
service_implementation: ServiceImplementation) -> ir.DictAttr:
"""Register a ServiceImplementation generator with the PyCDE generator.
Called when the ServiceImplamentation is defined."""

Expand All @@ -407,7 +406,7 @@ def _implement_service(self, req: ir.Operation):
return False
(impl, sys) = self._registry[impl_name]
with sys:
return impl.generate(serviceReq=req.opview)
return impl._builder.generate_svc_impl(serviceReq=req.opview)


_service_generator_registry = _ServiceGeneratorRegistry()
Expand Down
Loading

0 comments on commit 1e680d1

Please sign in to comment.