Skip to content

Commit 087d2d5

Browse files
Pass kwargs to Cofunction.interpolate (#4587)
* Pass kwargs to Cofunction.interpolate --------- Co-authored-by: Connor Ward <[email protected]>
1 parent f035f25 commit 087d2d5

File tree

2 files changed

+50
-43
lines changed

2 files changed

+50
-43
lines changed

firedrake/cofunction.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -308,14 +308,33 @@ def __imul__(self, expr):
308308
return self
309309
return NotImplemented
310310

311-
def interpolate(self, expression):
312-
r"""Interpolate an expression onto this :class:`Cofunction`.
313-
314-
:param expression: a UFL expression to interpolate
315-
:returns: this :class:`firedrake.cofunction.Cofunction` object"""
316-
from firedrake import interpolation
317-
interp = interpolation.Interpolate(ufl_expr.Argument(self.function_space().dual(), 0), expression)
318-
return firedrake.assemble(interp, tensor=self)
311+
@PETSc.Log.EventDecorator()
312+
def interpolate(self,
313+
expression: ufl.BaseForm,
314+
ad_block_tag: str | None = None,
315+
**kwargs):
316+
"""Interpolate a dual expression onto this :class:`Cofunction`.
317+
318+
Parameters
319+
----------
320+
expression
321+
A dual UFL expression to interpolate.
322+
ad_block_tag
323+
An optional string for tagging the resulting assemble
324+
block on the Pyadjoint tape.
325+
**kwargs
326+
Any extra kwargs are passed on to the interpolate function.
327+
For details see `firedrake.interpolation.interpolate`.
328+
329+
Returns
330+
-------
331+
firedrake.cofunction.Cofunction
332+
Returns `self`
333+
"""
334+
from firedrake import interpolation, assemble
335+
v, = self.arguments()
336+
interp = interpolation.Interpolate(v, expression, **kwargs)
337+
return assemble(interp, tensor=self, ad_block_tag=ad_block_tag)
319338

320339
@property
321340
def cell_set(self):

firedrake/function.py

Lines changed: 23 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -360,43 +360,31 @@ def function_space(self):
360360
return self._function_space
361361

362362
@PETSc.Log.EventDecorator()
363-
def interpolate(
364-
self,
365-
expression,
366-
subset=None,
367-
allow_missing_dofs=False,
368-
default_missing_val=None,
369-
ad_block_tag=None
370-
):
371-
r"""Interpolate an expression onto this :class:`Function`.
372-
373-
:param expression: a UFL expression to interpolate
374-
:kwarg subset: An optional :class:`pyop2.types.set.Subset` to apply the
375-
interpolation over. Cannot, at present, be used when interpolating
376-
across meshes unless the target mesh is a :func:`.VertexOnlyMesh`.
377-
:kwarg allow_missing_dofs: For interpolation across meshes: allow
378-
degrees of freedom (aka DoFs/nodes) in the target mesh that cannot be
379-
defined on the source mesh. For example, where nodes are point
380-
evaluations, points in the target mesh that are not in the source mesh.
381-
When ``False`` this raises a ``ValueError`` should this occur. When
382-
``True`` the corresponding values are set to zero or to the value
383-
``default_missing_val`` if given. Ignored if interpolating within the
384-
same mesh or onto a :func:`.VertexOnlyMesh` (the behaviour of a
385-
:func:`.VertexOnlyMesh` in this scenario is, at present, set when
386-
it is created).
387-
:kwarg default_missing_val: For interpolation across meshes: the optional
388-
value to assign to DoFs in the target mesh that are outside the source
389-
mesh. If this is not set then zero is used. Ignored if interpolating
390-
within the same mesh or onto a :func:`.VertexOnlyMesh`.
391-
:kwarg ad_block_tag: An optional string for tagging the resulting assemble block on
392-
the Pyadjoint tape.
393-
:returns: this :class:`Function` object"""
363+
def interpolate(self,
364+
expression: ufl.classes.Expr,
365+
ad_block_tag: str | None = None,
366+
**kwargs):
367+
"""Interpolate an expression onto this :class:`Function`.
368+
369+
Parameters
370+
----------
371+
expression
372+
A UFL expression to interpolate.
373+
ad_block_tag
374+
An optional string for tagging the resulting assemble
375+
block on the Pyadjoint tape.
376+
**kwargs
377+
Any extra kwargs are passed on to the interpolate function.
378+
For details see `firedrake.interpolation.interpolate`.
379+
380+
Returns
381+
-------
382+
firedrake.function.Function
383+
Returns `self`
384+
"""
394385
from firedrake import interpolation, assemble
395386
V = self.function_space()
396-
interp = interpolation.Interpolate(expression, V,
397-
subset=subset,
398-
allow_missing_dofs=allow_missing_dofs,
399-
default_missing_val=default_missing_val)
387+
interp = interpolation.Interpolate(expression, V, **kwargs)
400388
return assemble(interp, tensor=self, ad_block_tag=ad_block_tag)
401389

402390
def zero(self, subset=None):

0 commit comments

Comments
 (0)