Skip to content

Commit d9bdf14

Browse files
committed
fixes
1 parent fca4a3b commit d9bdf14

File tree

1 file changed

+28
-42
lines changed

1 file changed

+28
-42
lines changed

firedrake/interpolation.py

Lines changed: 28 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,18 @@ def __init__(self, expr: Interpolate, bcs: Iterable[BCBase] | None = None):
249249
self.access = expr.options.access
250250

251251
@abc.abstractmethod
252-
def _build_callable(self, output=None) -> None:
252+
def _build_callable(self, output: Function | Cofunction | MatrixBase | None = None) -> None:
253253
"""Builds callable to perform interpolation. Stored in ``self.callable``.
254+
255+
If ``self.rank == 2``, then ``self.callable()`` must return an object with a ``handle``
256+
attribute that stores a PETSc matrix. If ``self.rank == 1``, then `self.callable()` must
257+
return a ``Function`` or ``Cofunction`` (in the forward and adjoint cases respectively).
258+
If ``self.rank == 0``, then ``self.callable()`` must return a number.
259+
260+
Parameters
261+
----------
262+
output : Function | Cofunction | MatrixBase | None, optional
263+
Optional tensor to store the result in, by default None
254264
"""
255265
pass
256266

@@ -268,8 +278,8 @@ def assemble(
268278
----------
269279
tensor : Function | Cofunction | MatrixBase, optional
270280
Pre-allocated storage to receive the interpolated result. For rank-2
271-
expressions this is expected to be a
272-
:class:`~firedrake.assemble.AssembledMatrix`-compatible object whose
281+
expressions this is expected to be a subclass of
282+
:class:`~firedrake.matrix.MatrixBase` whose
273283
``petscmat`` will be populated. For lower-rank expressions this is
274284
a :class:`~firedrake.Function` or :class:`~firedrake.Cofunction`.
275285
@@ -280,27 +290,24 @@ def assemble(
280290
interpolation.
281291
"""
282292
self._build_callable(output=tensor)
283-
assembled_interpolator = self.callable()
293+
result = self.callable()
284294
if self.rank == 2:
285295
# Assembling the operator
286296
assert isinstance(tensor, MatrixBase | None)
287297
res = tensor.petscmat if tensor else PETSc.Mat()
288298
# Get the interpolation matrix
289-
petsc_mat = assembled_interpolator.handle
299+
petsc_mat = result.handle
290300
if tensor:
291301
petsc_mat.copy(tensor.petscmat)
292302
else:
293303
res = petsc_mat
294304
return tensor or firedrake.AssembledMatrix(self.expr_args, self.bcs, res)
295305
else:
296306
assert isinstance(tensor, Function | Cofunction | None)
297-
if tensor:
298-
tensor.assign(assembled_interpolator)
307+
if tensor and isinstance(result, Function | Cofunction):
308+
tensor.assign(result)
299309
return tensor
300-
if self.rank == 0:
301-
return assembled_interpolator.dat.data.item()
302-
else:
303-
return assembled_interpolator
310+
return result
304311

305312

306313
class DofNotDefinedError(Exception):
@@ -485,7 +492,7 @@ def callable():
485492

486493
def callable():
487494
assemble(action(self.point_eval_input_ordering, f_point_eval),
488-
tensor=f_point_eval_input_ordering)
495+
tensor=f_point_eval_input_ordering)
489496

490497
# We assign these values to the output function
491498
if self.allow_missing_dofs and self.default_missing_val is None:
@@ -581,12 +588,6 @@ def _get_tensor(self) -> op2.Mat | Function | Cofunction:
581588
return f
582589

583590
def _build_callable(self, output=None) -> None:
584-
"""Construct the callable that performs the interpolation.
585-
586-
Returns
587-
-------
588-
Callable
589-
"""
590591
f = output or self._get_tensor()
591592
tensor = f if isinstance(f, op2.Mat) else f.dat
592593

@@ -628,28 +629,10 @@ def _build_callable(self, output=None) -> None:
628629
def callable(loops, f):
629630
for l in loops:
630631
l()
631-
return f
632+
return f.dat.data.item() if self.rank == 0 else f
632633

633634
self.callable = partial(callable, loops, f)
634635

635-
@PETSc.Log.EventDecorator()
636-
def _interpolate(self, output=None):
637-
"""Compute the interpolation.
638-
639-
For arguments, see :class:`.Interpolator`.
640-
"""
641-
assert self.rank < 2
642-
self._build_callable(output=output)
643-
assembled_interpolator = self.callable()
644-
if output:
645-
output.assign(assembled_interpolator)
646-
return output
647-
648-
if self.rank == 0:
649-
return assembled_interpolator.dat.data.item()
650-
else:
651-
return assembled_interpolator
652-
653636

654637
class VomOntoVomInterpolator(SameMeshInterpolator):
655638

@@ -1418,7 +1401,8 @@ def __iter__(self):
14181401

14191402
def _build_callable(self, output=None):
14201403
"""Assemble the operator."""
1421-
f = output or Function(self.expr_args[-1].function_space().dual())
1404+
V_dest = self.expr.function_space() or self.target_space
1405+
f = output or Function(V_dest)
14221406
if self.rank == 2:
14231407
shape = tuple(len(a.function_space()) for a in self.expr_args)
14241408
blocks = numpy.full(shape, PETSc.Mat(), dtype=object)
@@ -1427,14 +1411,16 @@ def _build_callable(self, output=None):
14271411
blocks[i] = self[i].callable().handle
14281412
petscmat = PETSc.Mat().createNest(blocks)
14291413
tensor = firedrake.AssembledMatrix(self.expr_args, self.bcs, petscmat)
1430-
self.callable = lambda: tensor.M
1414+
callable = lambda: tensor.M
14311415
elif self.rank == 1:
14321416
def callable():
14331417
for k, sub_tensor in enumerate(f.subfunctions):
14341418
sub_tensor.assign(sum(self[i].assemble() for i in self if i[0] == k))
14351419
return f
1436-
self.callable = callable
14371420
else:
1421+
assert self.rank == 0
14381422
def callable():
1439-
return sum(self[i].assemble() for i in self)
1440-
self.callable = callable
1423+
result = sum(self[i].assemble() for i in self)
1424+
assert isinstance(result, Number)
1425+
return result
1426+
self.callable = callable

0 commit comments

Comments
 (0)