Skip to content

Commit ffe0dee

Browse files
add Topology.elementwise_stack
This patch adds `Topology.elementwise_stack` which stacks a function bound to every element of a topology.
1 parent 4864766 commit ffe0dee

File tree

1 file changed

+107
-0
lines changed

1 file changed

+107
-0
lines changed

nutils/topology.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,77 @@ def sample(self, ischeme: str, degree: int) -> Sample:
360360

361361
raise NotImplementedError
362362

363+
def elementwise_stack(self, func: function.Array, *, axis: int = 0) -> function.Array:
364+
'''Return the stack of the function bound to every element of this topology.
365+
366+
The order of the stack matches the order of the elements of this
367+
topology. Therefor, for any :class:`Topology` ``topo``
368+
369+
topo.elementwise_stack(topo.f_index)
370+
371+
is guaranteed to be the same (after evaluation) as
372+
373+
numpy.arange(len(topo))
374+
375+
Parameters
376+
----------
377+
func : :class:`nutils.function.Array`
378+
The function to bind to elements and stack.
379+
axis : :class:`int`
380+
The axis in the result array along which the elements are stacked.
381+
382+
Returns
383+
-------
384+
:class:`nutils.function.Array`
385+
The stack of function at the first element, the second, etc. until
386+
the last.
387+
388+
Examples
389+
--------
390+
391+
Given a ``base`` :class:`Topology` and a ``refinement``, with
392+
:meth:`Topology.elementwise_stack` we can obtain the element index of
393+
the ``base`` for each element in the ``refinement``:
394+
395+
>>> from nutils import function, mesh
396+
>>> base, geom = mesh.line(4)
397+
>>> refinement = base.refined_by([1, 3]).refined_by([5])
398+
>>> function.eval(refinement.elementwise_stack(base.f_index))
399+
array([0, 2, 1, 1, 3, 3, 3])
400+
401+
Note that refined elements are positioned at the end of the refined
402+
topology.
403+
404+
Notes
405+
-----
406+
407+
This method delegates the work to
408+
:meth:`Topology.elementwise_stack_last_axis`. Implementations of
409+
:class:`Topology` should implement the latter.
410+
'''
411+
412+
stack = self.elementwise_stack_last_axis(func)
413+
axis = numeric.normdim(stack.ndim, axis)
414+
if axis != stack.ndim - 1:
415+
stack = numpy.transpose(stack, (*range(axis), *range(axis + 1, stack.ndim), axis))
416+
return stack
417+
418+
def elementwise_stack_last_axis(self, func: function.Array) -> function.Array:
419+
'''Return the stack along the last axis of the function bound to every element of this topology.
420+
421+
Parameters
422+
----------
423+
func : :class:`nutils.function.Array`
424+
The function to bind to elements and stack.
425+
426+
Returns
427+
-------
428+
:class:`nutils.function.Array`
429+
The stack of function at element 0, element 1, etc along the last axis.
430+
'''
431+
432+
raise NotImplementedError
433+
363434
@single_or_multiple
364435
def integrate_elementwise(self, funcs: Iterable[function.Array], *, degree: int, asfunction: bool = False, ischeme: str = 'gauss', arguments: Optional[_ArgDict] = None) -> Union[List[numpy.ndarray], List[function.Array]]:
365436
'element-wise integration'
@@ -1042,6 +1113,10 @@ def basis_std(self, degree: int, *args, **kwargs) -> function.Array:
10421113
def sample(self, ischeme: str, degree: int) -> Sample:
10431114
return Sample.empty(self.spaces, self.ndims)
10441115

1116+
def elementwise_stack_last_axis(self, func: function.Array) -> function.Array:
1117+
func = function.Array.cast(func)
1118+
return function.zeros((*func.shape, 0), dtype=func.dtype)
1119+
10451120

10461121
class _DisjointUnion(_TensorialTopology):
10471122

@@ -1096,6 +1171,9 @@ def integrate_elementwise(self, funcs: Iterable[function.Array], *, degree: int,
10961171
def sample(self, ischeme: str, degree: int) -> Sample:
10971172
return self.topo1.sample(ischeme, degree) + self.topo2.sample(ischeme, degree)
10981173

1174+
def elementwise_stack_last_axis(self, func: function.Array) -> function.Array:
1175+
return numpy.concatenate([self.topo1.elementwise_stack_last_axis(func), self.topo2.elementwise_stack_last_axis(func)], axis=-1)
1176+
10991177
def trim(self, levelset: function.Array, maxrefine: int, ndivisions: int = 8, name: str = 'trimmed', leveltopo: Optional[Topology] = None, *, arguments: Optional[_ArgDict] = None) -> Topology:
11001178
if leveltopo is not None:
11011179
return super().trim(levelset, maxrefine, ndivisions, name, leveltopo, arguments=arguments)
@@ -1245,6 +1323,10 @@ def basis(self, name: str, degree: Union[int, Sequence[int]], **kwargs) -> funct
12451323
def sample(self, ischeme: str, degree: int) -> Sample:
12461324
return self.topo1.sample(ischeme, degree) * self.topo2.sample(ischeme, degree)
12471325

1326+
def elementwise_stack_last_axis(self, func: function.Array) -> function.Array:
1327+
func = function.Array.cast(func)
1328+
return numpy.reshape(self.topo2.elementwise_stack_last_axis(self.topo1.elementwise_stack_last_axis(func)), (*func.shape, len(self)))
1329+
12481330

12491331
class _Take(_TensorialTopology):
12501332

@@ -1261,6 +1343,9 @@ def __init__(self, parent: Topology, indices: numpy.ndarray) -> None:
12611343
def sample(self, ischeme: str, degree: int) -> Sample:
12621344
return self.parent.sample(ischeme, degree).take_elements(self.indices)
12631345

1346+
def elementwise_stack_last_axis(self, func: function.Array) -> function.Array:
1347+
return numpy.take(self.parent.elementwise_stack_last_axis(func), self.indices, axis=-1)
1348+
12641349

12651350
class _WithGroupAliases(_TensorialTopology):
12661351

@@ -1307,6 +1392,9 @@ def basis(self, name: str, *args, **kwargs) -> function.Basis:
13071392
def sample(self, ischeme: str, degree: int) -> Sample:
13081393
return self.parent.sample(ischeme, degree)
13091394

1395+
def elementwise_stack_last_axis(self, func: function.Array) -> function.Array:
1396+
return self.parent.elementwise_stack_last_axis(func)
1397+
13101398
def refine_spaces_unchecked(self, spaces: FrozenSet[str]) -> Topology:
13111399
return _WithGroupAliases(self.parent.refine_spaces(spaces), self.vgroups, self.bgroups, self.igroups)
13121400

@@ -1326,6 +1414,22 @@ def interfaces_spaces_unchecked(self, spaces: FrozenSet[str]) -> Topology:
13261414
return _WithGroupAliases(self.parent.interfaces_spaces_unchecked(spaces), self.igroups, types.frozendict({}), types.frozendict({}))
13271415

13281416

1417+
class _ElementwiseStack(function.Array):
1418+
1419+
def __init__(self, func: function.Array, space: str, transforms: transformseq.Transforms, opposites: transformseq.Transforms):
1420+
self.func = func
1421+
self.space = space
1422+
self.transforms = transforms
1423+
self.opposites = opposites
1424+
self.nelems = len(self.transforms)
1425+
super().__init__((*func.shape, self.nelems), func.dtype, func.spaces - frozenset({space}), func.arguments)
1426+
1427+
def lower(self, args: function.LowerArgs) -> evaluable.Array:
1428+
index = evaluable.loop_index(f'_loop_{self.space}', evaluable.asarray(self.nelems))
1429+
func = self.func.lower(args | function.LowerArgs.for_space(self.space, (self.transforms, self.opposites), index))
1430+
return evaluable.loop_concatenate(evaluable.appendaxes(func, (evaluable.asarray(1),)), index)
1431+
1432+
13291433
class TransformChainsTopology(Topology):
13301434
'base class for topologies with transform chains'
13311435

@@ -1433,6 +1537,9 @@ def sample(self, ischeme, degree):
14331537
transforms += self.opposites,
14341538
return Sample.new(self.space, transforms, points)
14351539

1540+
def elementwise_stack_last_axis(self, func: function.Array) -> function.Array:
1541+
return _ElementwiseStack(function.Array.cast(func), self.space, self.transforms, self.opposites)
1542+
14361543
def _refined_by(self, refine):
14371544
fine = self.refined.transforms
14381545
indices0 = numpy.setdiff1d(numpy.arange(len(self)), refine, assume_unique=True)

0 commit comments

Comments
 (0)