@@ -249,8 +249,18 @@ def __init__(self, expr: Interpolate, bcs: Iterable[BCBase] | None = None):
249249 self .access = expr .options .access
250250
251251 @abc .abstractmethod
252- def _build_callable (self , output = None ) -> None :
252+ def _build_callable (self , output : Function | Cofunction | MatrixBase | None = None ) -> None :
253253 """Builds callable to perform interpolation. Stored in ``self.callable``.
254+
255+ If ``self.rank == 2``, then ``self.callable()`` must return an object with a ``handle``
256+ attribute that stores a PETSc matrix. If ``self.rank == 1``, then `self.callable()` must
257+ return a ``Function`` or ``Cofunction`` (in the forward and adjoint cases respectively).
258+ If ``self.rank == 0``, then ``self.callable()`` must return a number.
259+
260+ Parameters
261+ ----------
262+ output : Function | Cofunction | MatrixBase | None, optional
263+ Optional tensor to store the result in, by default None
254264 """
255265 pass
256266
@@ -268,8 +278,8 @@ def assemble(
268278 ----------
269279 tensor : Function | Cofunction | MatrixBase, optional
270280 Pre-allocated storage to receive the interpolated result. For rank-2
271- expressions this is expected to be a
272- :class:`~firedrake.assemble.AssembledMatrix`-compatible object whose
281+ expressions this is expected to be a subclass of
282+ :class:`~firedrake.matrix.MatrixBase` whose
273283 ``petscmat`` will be populated. For lower-rank expressions this is
274284 a :class:`~firedrake.Function` or :class:`~firedrake.Cofunction`.
275285
@@ -280,27 +290,24 @@ def assemble(
280290 interpolation.
281291 """
282292 self ._build_callable (output = tensor )
283- assembled_interpolator = self .callable ()
293+ result = self .callable ()
284294 if self .rank == 2 :
285295 # Assembling the operator
286296 assert isinstance (tensor , MatrixBase | None )
287297 res = tensor .petscmat if tensor else PETSc .Mat ()
288298 # Get the interpolation matrix
289- petsc_mat = assembled_interpolator .handle
299+ petsc_mat = result .handle
290300 if tensor :
291301 petsc_mat .copy (tensor .petscmat )
292302 else :
293303 res = petsc_mat
294304 return tensor or firedrake .AssembledMatrix (self .expr_args , self .bcs , res )
295305 else :
296306 assert isinstance (tensor , Function | Cofunction | None )
297- if tensor :
298- tensor .assign (assembled_interpolator )
307+ if tensor and isinstance ( result , Function | Cofunction ) :
308+ tensor .assign (result )
299309 return tensor
300- if self .rank == 0 :
301- return assembled_interpolator .dat .data .item ()
302- else :
303- return assembled_interpolator
310+ return result
304311
305312
306313class DofNotDefinedError (Exception ):
@@ -485,7 +492,7 @@ def callable():
485492
486493 def callable ():
487494 assemble (action (self .point_eval_input_ordering , f_point_eval ),
488- tensor = f_point_eval_input_ordering )
495+ tensor = f_point_eval_input_ordering )
489496
490497 # We assign these values to the output function
491498 if self .allow_missing_dofs and self .default_missing_val is None :
@@ -581,12 +588,6 @@ def _get_tensor(self) -> op2.Mat | Function | Cofunction:
581588 return f
582589
583590 def _build_callable (self , output = None ) -> None :
584- """Construct the callable that performs the interpolation.
585-
586- Returns
587- -------
588- Callable
589- """
590591 f = output or self ._get_tensor ()
591592 tensor = f if isinstance (f , op2 .Mat ) else f .dat
592593
@@ -628,28 +629,10 @@ def _build_callable(self, output=None) -> None:
628629 def callable (loops , f ):
629630 for l in loops :
630631 l ()
631- return f
632+ return f . dat . data . item () if self . rank == 0 else f
632633
633634 self .callable = partial (callable , loops , f )
634635
635- @PETSc .Log .EventDecorator ()
636- def _interpolate (self , output = None ):
637- """Compute the interpolation.
638-
639- For arguments, see :class:`.Interpolator`.
640- """
641- assert self .rank < 2
642- self ._build_callable (output = output )
643- assembled_interpolator = self .callable ()
644- if output :
645- output .assign (assembled_interpolator )
646- return output
647-
648- if self .rank == 0 :
649- return assembled_interpolator .dat .data .item ()
650- else :
651- return assembled_interpolator
652-
653636
654637class VomOntoVomInterpolator (SameMeshInterpolator ):
655638
@@ -1418,7 +1401,8 @@ def __iter__(self):
14181401
14191402 def _build_callable (self , output = None ):
14201403 """Assemble the operator."""
1421- f = output or Function (self .expr_args [- 1 ].function_space ().dual ())
1404+ V_dest = self .expr .function_space () or self .target_space
1405+ f = output or Function (V_dest )
14221406 if self .rank == 2 :
14231407 shape = tuple (len (a .function_space ()) for a in self .expr_args )
14241408 blocks = numpy .full (shape , PETSc .Mat (), dtype = object )
@@ -1427,14 +1411,16 @@ def _build_callable(self, output=None):
14271411 blocks [i ] = self [i ].callable ().handle
14281412 petscmat = PETSc .Mat ().createNest (blocks )
14291413 tensor = firedrake .AssembledMatrix (self .expr_args , self .bcs , petscmat )
1430- self . callable = lambda : tensor .M
1414+ callable = lambda : tensor .M
14311415 elif self .rank == 1 :
14321416 def callable ():
14331417 for k , sub_tensor in enumerate (f .subfunctions ):
14341418 sub_tensor .assign (sum (self [i ].assemble () for i in self if i [0 ] == k ))
14351419 return f
1436- self .callable = callable
14371420 else :
1421+ assert self .rank == 0
14381422 def callable ():
1439- return sum (self [i ].assemble () for i in self )
1440- self .callable = callable
1423+ result = sum (self [i ].assemble () for i in self )
1424+ assert isinstance (result , Number )
1425+ return result
1426+ self .callable = callable
0 commit comments