Skip to content

Commit

Permalink
Improve logic for MTSplit (#66)
Browse files Browse the repository at this point in the history
* Add test to check for non anonymous kernels and MTSplit and MTConcat

* Do not trace application again if args are the same

* Fix call

* Remove unused test

* Add testcases with float numbers

* Add memtile reduce test

* Flake8

* Fix source name

* Fix name

* Fix behavior

* Add more comprehensive tests

* Flake8

* Increase tolerance

* Set random seed

* Raise error if datatype not supported for mt links

* Remove trailing whitespace and some flake8
  • Loading branch information
mariodruiz authored Oct 15, 2024
1 parent 3860084 commit 6338ba4
Show file tree
Hide file tree
Showing 5 changed files with 340 additions and 76 deletions.
27 changes: 17 additions & 10 deletions npu/build/appbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(self, name=None) -> None:
self.connections = None

self.previous_build_args = None
self._metadata = None

def __call__(self, *args):
""" Calling the class will execute the callgraph directly."""
Expand All @@ -64,17 +65,21 @@ def __call__(self, *args):

def callgraph(self):
    """Build the application graph; must be overridden by a subclass.

    Raises
    ------
    NotImplementedError
        Always, on the base class: subclasses supply the callgraph used
        for tracing and behavioral execution.
    """
    # Bug fix: implicit string-literal concatenation was missing a space
    # between 'behavioral' and 'execution' ("behavioralexecution").
    raise NotImplementedError('Subclass needs to implement the callgraph '
                              'function for use in tracing and behavioral '
                              'execution')

def to_metadata(self, *args):
""" The application is converted into the AppMetadata after tracing the callgraph() call."""
self.previous_build_args = args
self.kernels, self.connections = self.fxtracer.to_trace(*args)
if self._metadata is None or self.previous_build_args != args:
self.kernels, self.connections = self.fxtracer.to_trace(*args)

return AppMetada(self.name,
self.unique_named(self.kernels),
self.unique_named(self.connections),
self.to_sequence())
self._metadata= AppMetada(self.name,
self.unique_named(self.kernels),
self.unique_named(self.connections),
self.to_sequence())
self.previous_build_args = args
return self._metadata

def to_handoff(self, *args, file=None):
""" Converts the application into a serializable JSON file."""
Expand All @@ -89,7 +94,7 @@ def to_json(self, *args):

@property
def metadata(self, *args):
""" Generates the application JSON and displays inside a IPython environment."""
"""Generates the application JSON and displays inside a IPython environment"""
from npu import ReprDict
self.validate_previous_build_args()
return ReprDict(self.to_json(*self.previous_build_args), rootname=self.name)
Expand Down Expand Up @@ -147,8 +152,10 @@ def __add__(self, app_component):

def validate_previous_build_args(self):
if self.previous_build_args is None:
raise ValueError(f'Before using this AppBuilder API, please first call the AppBuilder instance directly or call \
to_metadata(), to_json() or to_build() with callgraph args to complete the application graph')
raise ValueError('Before using this AppBuilder API, please first '
'call the AppBuilder instance directly or call '
'to_metadata(), to_json() or to_build() with '
'callgraph args to complete the application graph')

def merge_applications(self, newkernels, newconnections):
self.connections.extend(newconnections)
Expand Down
67 changes: 32 additions & 35 deletions npu/build/mlirbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
from .mlirtiles import CTTile, MEMTile, ITTile
from .buffers import MTBuffer
from .mlirconnections import MLIRConnect, ObjectFIFO
from .mlirsequencebuilder import MLIRSequnceBuilder
from .mlirsequencebuilder import MLIRSequnceBuilder
from typing import Dict, List, Tuple
from collections import OrderedDict
from itertools import groupby
import ctypes


class MLIRBuilder:
Expand All @@ -17,7 +15,7 @@ class MLIRBuilder:
Attributes
----------
metadata : AppMetadata
The application metadata.
The application metadata.
app : JSON
The json representation of an application metadata.
kernels : dict
Expand All @@ -31,7 +29,7 @@ class MLIRBuilder:
"""

def __init__(self, metadata, config=(4,1,1)):
"""Return a new MLIRBuilder object."""
"""Return a new MLIRBuilder object."""
app = metadata.to_json()
self.kernels = app['kernels']
self.connections = app['connections']
Expand All @@ -43,7 +41,7 @@ def __init__(self, metadata, config=(4,1,1)):
MLIRConnect.reset_id()

self.aietiles, self.memtiles, self.sdmatiles = self._parse_tiles(config)
self.tiles = {**self.sdmatiles, **self.memtiles, **self.aietiles}
self.tiles = {**self.sdmatiles, **self.memtiles, **self.aietiles}
self._map_kernels_to_tiles()

self._cons_src2dst = self._populate_src2dst_cons_dict()
Expand All @@ -63,12 +61,11 @@ def _parse_tiles(self, config):
sdmas = {(ix, 0) : ITTile(ix, 0) for ix in range(config[1])}
return aies, mts, sdmas


def to_mlir(self, file=None):
"""Toplevel method to generate the application MLIR."""
indent = " "
used_tiles = [tile for _, tile in self.tiles.items() if tile.is_used()]
used_aie_tiles = [tile for _, tile in self.aietiles.items() if tile.is_used()]
used_aie_tiles = [tile for _, tile in self.aietiles.items() if tile.is_used()]

s = 'module {\n'
s += f'{indent}AIE.device(ipu)'
Expand All @@ -91,7 +88,7 @@ def to_mlir(self, file=None):

for _, aie in self.aietiles.items():
s += aie.to_mlir()

s += f'{self.seqbuilder.mlir}'
s += ' }\n'
s += "}\n"
Expand All @@ -102,8 +99,7 @@ def to_mlir(self, file=None):
with open(file, "w") as f:
f.write(s)

return ""

return ""

def _map_kernels_to_tiles(self):

Expand All @@ -113,7 +109,7 @@ def _map_kernels_to_tiles(self):

if self.aietiles[k['tloc']].kernel is not None:
raise ValueError(f'Cannot place {k["name"]} kernel - CT tile previously constrained with {self.aietiles[k["tloc"]].kernel["name"]}')
self.aietiles[k['tloc']].kernel = k
self.aietiles[k['tloc']].kernel = k

# Then AIE kernels onto any free AIE Tiles
# ...place all buffers on first MT / SDMA tiles
Expand All @@ -127,7 +123,7 @@ def _map_kernels_to_tiles(self):
for _, mt in self.memtiles.items():
k['tloc'] = mt.tloc
mt.buffers[kname] = k
break
break
elif k['type'] == 'IT':
for _, sdma in self.sdmatiles.items():
k['tloc'] = sdma.tloc
Expand All @@ -148,7 +144,7 @@ def _populate_src2dst_cons_dict(self) -> Dict[Tuple[str,str], List[Tuple[str,str
else:
cons_dict[con_src] = [con_dst]
return cons_dict

def _populate_dst2src_cons_dict(self) -> Dict[Tuple[str,str], List[Tuple[str,str]]]:
""" Creates a mapping dict of the connections where the key is the dst (kernel, port) tuple and
the value is a list of source (kernel,port) tuples."""
Expand All @@ -160,8 +156,8 @@ def _populate_dst2src_cons_dict(self) -> Dict[Tuple[str,str], List[Tuple[str,str
cons_dict[con_dst].append(con_src)
else:
cons_dict[con_dst] = [con_src]
return cons_dict
return cons_dict

def _populate_broadcast_dict(self) -> Dict[Tuple[str,str], List[Tuple[str,str]]]:
""" Filters the dict of src2dst connection mappings produced by _populate_src2dst_cons_dict
to produce a dict that only contains the broadcast pattern where the same data is going
Expand All @@ -176,7 +172,7 @@ def _populate_broadcast_dict(self) -> Dict[Tuple[str,str], List[Tuple[str,str]]]
broadcasts[src] = dsts
elif bc_analysis == "MIX":
raise RuntimeError(f"""
Mixing broadcasts and distributes from the same source is not yet supported
Mixing broadcasts and distributes from the same source is not yet supported
{src=} to {dsts=}
""")
return broadcasts
Expand All @@ -188,36 +184,34 @@ def _all_equal(self,outgoing_shapes)->bool:
return False
return True


def _all_unique(self, outgoing_shapes)->bool:
def _all_unique(self, outgoing_shapes)->bool:
seen_list = []
for lst in outgoing_shapes.values():
if lst in seen_list:
return False
return True

def _analyse_broadcast(self, src:Tuple[str,str]) -> str:
"""When given a source (kernel, port) tuple determines if it is a:
"""When given a source (kernel, port) tuple determines if it is a:
true broadcast, i.e. same data unique destinations (returns "BCAST");
distribute op, i.e. different chunks of the data to different destinations (returns "DIST");
a mix of both (returns "MIX").
"""
outgoing_shapes = {}
outgoing_shapes = {}
for s in self.sequence:
con_src = (s['srckernelname'], s['srcportname'])
if con_src == src:
dst = (s['snkkernelname'], s['snkportname'])
outgoing_shape = (s['srcslices'], s['srcoffset'], s['srcnbytes'])
if dst not in outgoing_shapes:
outgoing_shapes[dst] = []
outgoing_shapes[dst] = []
outgoing_shapes[dst].append(outgoing_shape)
if self._all_equal(outgoing_shapes):
return "BCAST"
elif self._all_unique(outgoing_shapes):
elif self._all_unique(outgoing_shapes):
return "DIST"
else:
return "MIX"


def _get_bcast_nbytes_offset(self, bcast_src:Tuple[str,str])->Tuple[int,int]:
for s in self.sequence:
Expand All @@ -229,11 +223,11 @@ def _get_bcast_nbytes_offset(self, bcast_src:Tuple[str,str])->Tuple[int,int]:
def _map_connections_to_objectfifos(self):
obfs = list()
for s in [s for s in self.sequence if s['seqtype'] == 'buffer']:
for c in [c for c in self.connections if s['name'] == c]:
for c in [c for c in self.connections if s['name'] == c]:
if c in [obf.name for obf in obfs]:
# TODO : validate that channel transfer nbytes is consistent
break
self.connections[c]['ctype'] = "objfifo,pingpong"
self.connections[c]['ctype'] = "objfifo,pingpong"

src = (self.connections[c]['srckernel'],
self.connections[c]['srcport'])
Expand All @@ -243,11 +237,11 @@ def _map_connections_to_objectfifos(self):
self.connections[c]['sinkport'])]

obfs.append(ObjectFIFO(c, src, dsts, s['nbytes'], s['offset'], self.tiles, self.kernels))

# map broadcast connections
for src, dsts in self._cons_broadcasts.items():
nbytes, offset = self._get_bcast_nbytes_offset(src)
obfs.append(ObjectFIFO(f"{src[0]}__{src[1]}", src, dsts, nbytes, offset, self.tiles, self.kernels))
obfs.append(ObjectFIFO(f"{src[0]}__{src[1]}", src, dsts, nbytes, offset, self.tiles, self.kernels))

return obfs

Expand Down Expand Up @@ -277,7 +271,7 @@ def _get_mtbuffer_io(self)->Dict:
elif isinstance(c.sinkkernel, MTBuffer):
if c.sinkkernel.name not in mt_buff_links:
mt_buff_links[c.sinkkernel.name] = { 'in' : [], 'out' : [] }
mt_buff_links[c.sinkkernel.name]['in'].append(c)
mt_buff_links[c.sinkkernel.name]['in'].append(c)
mt_buff_links[c.sinkkernel.name]['in'] = self._sort_sinkport_mtbuff_links(mt_buff_links[c.sinkkernel.name]['in'])
return mt_buff_links

Expand All @@ -291,8 +285,10 @@ def sorting_key(connection):
return connection.srcport.slices[0]
if isinstance(connection.srcport.slices[0], slice):
return connection.srcport.slices[0].start
raise ValueError(f'{type(connection.srcport.slices[0])} not '
'supported on MTBuffer source port')
return sorted(buff_links, key=sorting_key)

def _sort_sinkport_mtbuff_links(self, buff_links)->List:
def sorting_key(connection):
if not connection.sinkport.slices:
Expand All @@ -303,14 +299,16 @@ def sorting_key(connection):
return connection.sinkport.slices[0]
if isinstance(connection.sinkport.slices[0], slice):
return connection.sinkport.slices[0].start
raise ValueError(f'{type(connection.sinkport.slices[0])} not '
'supported on MTBuffer sink port')
return sorted(buff_links, key=sorting_key)

def _is_con_broadcast(self, connection)->bool:
return (connection.srckernel.name, connection.srcport.name) in self._cons_broadcasts

def _get_objectfifo_varname(self, connection)->str:
if self._is_con_broadcast(connection):
name = f"{connection.srckernel.name}__{connection.srcport.name}"
name = f"{connection.srckernel.name}__{connection.srcport.name}"
else:
name = connection.name

Expand All @@ -333,7 +331,7 @@ def _link_objectfifos_via_memtile(self, indent='')->str:
if not self._is_mtlink_bcast(link):
s += self._generate_distribute_link(link, indent=indent)
else:
s += self._generate_broadcast_link(link, mt_buff_links, indent=indent)
s += self._generate_broadcast_link(link, mt_buff_links, indent=indent)

return s

Expand Down Expand Up @@ -370,11 +368,10 @@ def _generate_broadcast_link(self, link, mtlinks, indent='')->str:
return s
raise RuntimeError(f"Unable to find a feeding buffer to link to in the memtile for {link=}")


def _validate_app(self):

if len(self.kernels) == 0 or len(self.connections) == 0:
raise ValueError(f'{len(self.kernels)} kernels or {len(self.connections)} connections cannot be zero')
raise ValueError(f'{len(self.kernels)} kernels or {len(self.connections)} connections cannot be zero')

if len(self.aietiles) > len(self.aietiles):
raise ValueError(f'{len(self.kernels)} kernels cannot be placed onto {len(self.aietiles)} AIE tiles')
15 changes: 15 additions & 0 deletions tests/test_applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,21 @@ def callgraph(self, x, n):
return x


class MtSplitConcat4AIEsNonAnonymousPlusN(AppBuilder):
    """App with four named PlusN kernels fed by an MTSplit and merged by an MTConcat.

    The kernels, split and concat objects are created once in __init__ so
    they are non-anonymous (stable names across traces).
    """

    def __init__(self):
        super().__init__()
        # Four distinct, named kernel instances — one per split branch.
        self.kernels = [PlusN() for _ in range(4)]
        self.mtbsplit = MTSplit(4)
        self.mtbconcat = MTConcat()

    def callgraph(self, x_in, x_out, n):
        """Split the input four ways, apply PlusN to each chunk, concat the results."""
        split_inputs = self.mtbsplit(x_in)
        outputs = [kernel(chunk, chunk.nbytes, n)
                   for kernel, chunk in zip(self.kernels, split_inputs)]
        x_out[:] = self.mtbconcat(outputs)


class TwoInputsApp(AppBuilder):
def callgraph(self, k, a, b, size):
"""Callgraph that tests a kernel with two input arguments """
Expand Down
Loading

0 comments on commit 6338ba4

Please sign in to comment.