-
Notifications
You must be signed in to change notification settings - Fork 166
EnsembleFunction
and friends.
#4025
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
366fb41
8a8472f
77b7f8f
3d7c357
3c932d9
44166cd
a7192bd
fdaf00f
889ff11
e6b5698
65b90c4
f9c6538
7565a62
ee2f8b8
8f73d68
be0285e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from pyadjoint.overloaded_type import OverloadedType | ||
from firedrake.petsc import PETSc | ||
from .checkpointing import disk_checkpointing | ||
|
||
from functools import wraps | ||
|
||
|
||
class EnsembleFunctionMixin(OverloadedType): | ||
""" | ||
Basic functionality for EnsembleFunction to be OverloadedTypes. | ||
Note that currently no EnsembleFunction operations are taped. | ||
|
||
Enables EnsembleFunction to do the following: | ||
- Be a Control for a NumpyReducedFunctional (_ad_to_list and _ad_assign_numpy) | ||
- Be used with pyadjoint TAO solver (_ad_{to,from}_petsc) | ||
- Be used as a Control for Taylor tests (_ad_dot) | ||
""" | ||
|
||
@staticmethod | ||
def _ad_annotate_init(init): | ||
@wraps(init) | ||
def wrapper(self, *args, **kwargs): | ||
OverloadedType.__init__(self) | ||
init(self, *args, **kwargs) | ||
self._ad_add = self.__add__ | ||
self._ad_mul = self.__mul__ | ||
self._ad_iadd = self.__iadd__ | ||
self._ad_imul = self.__imul__ | ||
self._ad_copy = self.copy | ||
return wrapper | ||
|
||
@staticmethod | ||
def _ad_to_list(m): | ||
with m.vec_ro() as gvec: | ||
lvec = PETSc.Vec().createSeq(gvec.size, | ||
comm=PETSc.COMM_SELF) | ||
PETSc.Scatter().toAll(gvec).scatter( | ||
gvec, lvec, addv=PETSc.InsertMode.INSERT_VALUES) | ||
return lvec.array_r.tolist() | ||
|
||
@staticmethod | ||
def _ad_assign_numpy(dst, src, offset): | ||
with dst.vec_wo() as vec: | ||
begin, end = vec.owner_range | ||
vec.array[:] = src[offset + begin: offset + end] | ||
offset += vec.size | ||
return dst, offset | ||
|
||
def _ad_dot(self, other, options=None): | ||
local_dot = sum(uself._ad_dot(uother, options=options) | ||
for uself, uother in zip(self.subfunctions, | ||
other.subfunctions)) | ||
return self.ensemble.ensemble_comm.allreduce(local_dot) | ||
|
||
def _ad_convert_riesz(self, value, options=None): | ||
raise NotImplementedError | ||
|
||
def _ad_create_checkpoint(self): | ||
if disk_checkpointing(): | ||
raise NotImplementedError( | ||
"Disk checkpointing not implemented for EnsembleFunctions") | ||
else: | ||
return self.copy() | ||
|
||
def _ad_restore_at_checkpoint(self, checkpoint): | ||
if type(checkpoint) is type(self): | ||
return checkpoint | ||
raise NotImplementedError( | ||
"Disk checkpointing not implemented for EnsembleFunctions") | ||
|
||
def _ad_from_petsc(self, vec): | ||
with self.vec_wo as self_v: | ||
vec.copy(self_v) | ||
|
||
def _ad_to_petsc(self, vec=None): | ||
with self.vec_ro as self_v: | ||
return self_v.copy(vec or self._vec.duplicate()) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from firedrake.ensemble.ensemble import * # noqa: F401 | ||
from firedrake.ensemble.ensemble_function import * # noqa: F401 | ||
from firedrake.ensemble.ensemble_functionspace import * # noqa: F401 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,287 @@ | ||
from firedrake.petsc import PETSc | ||
from firedrake.ensemble.ensemble_functionspace import ( | ||
EnsembleFunctionSpaceBase, EnsembleFunctionSpace, EnsembleDualSpace) | ||
from firedrake.adjoint_utils import EnsembleFunctionMixin | ||
from firedrake.function import Function | ||
from firedrake.norms import norm | ||
from pyop2 import MixedDat | ||
|
||
from functools import cached_property | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Feel free to ignore (we violate this everywhere) but conventionally standard library packages go above third party ones |
||
from contextlib import contextmanager | ||
|
||
__all__ = ("EnsembleFunction", "EnsembleCofunction") | ||
|
||
|
||
class EnsembleFunctionBase(EnsembleFunctionMixin): | ||
""" | ||
A mixed (co)function defined on a :class:`firedrake.Ensemble`. | ||
The subcomponents are distributed over the ensemble members, and | ||
are specified locally in a :class:`firedrake.EnsembleFunctionSpace`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure what you mean by "are specified locally" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Anywhere I say "locally" I mean on each |
||
|
||
Parameters | ||
---------- | ||
|
||
function_space : `firedrake.EnsembleFunctionSpace`. | ||
The function space of the (co)function. | ||
|
||
Notes | ||
----- | ||
Passing an `EnsembleDualSpace` to `EnsembleFunction` | ||
will return an instance of :class:`firedrake.EnsembleCofunction`. | ||
|
||
This class does not carry UFL symbolic information, unlike a | ||
:class:`firedrake.Function`. UFL expressions can only be defined | ||
locally on each ensemble member using a `firedrake.Function` | ||
from `EnsembleFunction.subfunctions`. | ||
|
||
See Also | ||
-------- | ||
- Primal ensemble objects: :class:`firedrake.EnsembleFunctionSpace` and :class:`firedrake.EnsembleFunction`. | ||
- Dual ensemble objects: :class:`firedrake.EnsembleDualSpace` and :class:`firedrake.EnsembleCofunction`. | ||
""" | ||
|
||
@PETSc.Log.EventDecorator() | ||
@EnsembleFunctionMixin._ad_annotate_init | ||
def __init__(self, function_space: EnsembleFunctionSpaceBase): | ||
self._fs = function_space | ||
|
||
# we hold all subcomponents on the local | ||
# ensemble member in one big mixed function. | ||
self._full_local_function = Function(function_space._full_local_space) | ||
|
||
# create a Vec containing the data for all subcomponents on all | ||
# ensemble members. Because we use the Vec of each local mixed | ||
# function as the storage, if the data in the Function Vec | ||
# is valid then the data in the EnsembleFunction Vec is valid. | ||
|
||
with self._full_local_function.dat.vec as fvec: | ||
n = function_space.nlocal_rank_dofs | ||
N = function_space.nglobal_dofs | ||
sizes = (n, N) | ||
self._vec = PETSc.Vec().createWithArray( | ||
fvec.array, size=sizes, | ||
comm=function_space.global_comm) | ||
self._vec.setFromOptions() | ||
|
||
def function_space(self): | ||
return self._fs | ||
|
||
@cached_property | ||
def subfunctions(self): | ||
""" | ||
The (co)functions on the local ensemble member. | ||
""" | ||
def local_function(i): | ||
V = self._fs.local_spaces[i] | ||
usubs = self._subcomponents(i) | ||
if len(usubs) == 1: | ||
dat = usubs[0].dat | ||
else: | ||
dat = MixedDat((u.dat for u in usubs)) | ||
return Function(V, val=dat) | ||
|
||
return tuple(local_function(i) | ||
for i in range(self._fs.nlocal_spaces)) | ||
|
||
def _subcomponents(self, i): | ||
""" | ||
Return the subfunctions of the local mixed function storage | ||
corresponding to the i-th local function. | ||
""" | ||
return tuple(self._full_local_function.subfunctions[j] | ||
for j in self._fs._component_indices(i)) | ||
|
||
@PETSc.Log.EventDecorator() | ||
def riesz_representation(self, riesz_map="L2", **kwargs): | ||
""" | ||
Return the Riesz representation of this :class:`EnsembleFunction` | ||
with respect to the given Riesz map. | ||
|
||
Parameters | ||
---------- | ||
riesz_map | ||
The Riesz map to use (`l2`, `L2`, or `H1`). This can also be a callable. | ||
|
||
kwargs | ||
other arguments to be passed to the firedrake.riesz_map. | ||
""" | ||
riesz = EnsembleFunction(self._fs.dual()) | ||
for uself, uriesz in zip(self.subfunctions, riesz.subfunctions): | ||
uriesz.assign( | ||
uself.riesz_representation( | ||
riesz_map=riesz_map, **kwargs)) | ||
return riesz | ||
|
||
@PETSc.Log.EventDecorator() | ||
def assign(self, other, subsets=None): | ||
r"""Set the :class:`EnsembleFunction` to the value of another | ||
:class:`EnsembleFunction` other. | ||
|
||
Parameters | ||
---------- | ||
|
||
other : :class:`EnsembleFunction` | ||
The value to assign from. | ||
|
||
subsets : Collection[Optional[:class:`pyop2.types.set.Subset`]] | ||
One subset for each local :class:`Function`. None elements | ||
will be ignored. The values of each local function will | ||
only be assigned on the nodes on the corresponding subset. | ||
""" | ||
if type(other) is not type(self): | ||
raise TypeError( | ||
f"Cannot assign {type(self).__name__} from {type(other).__name__}") | ||
for i in range(self._fs.nlocal_spaces): | ||
self.subfunctions[i].assign( | ||
other.subfunctions[i], | ||
subset=subsets[i] if subsets else None) | ||
return self | ||
|
||
@PETSc.Log.EventDecorator() | ||
def copy(self): | ||
""" | ||
Return a deep copy of the :class:`EnsembleFunction`. | ||
""" | ||
new = type(self)(self.function_space()) | ||
new.assign(self) | ||
return new | ||
|
||
@PETSc.Log.EventDecorator() | ||
def zero(self, subsets=None): | ||
""" | ||
Set values to zero. | ||
|
||
Parameters | ||
---------- | ||
|
||
subsets : Collection[Optional[:class:`pyop2.types.set.Subset`]] | ||
One subset for each local :class:`Function`. None elements | ||
will be ignored. The values of each local function will | ||
only be zeroed on the nodes on the corresponding subset. | ||
""" | ||
for i in range(self._fs.nlocal_spaces): | ||
self.subfunctions[i].zero( | ||
subset=subsets[i] if subsets else None) | ||
return self | ||
|
||
@PETSc.Log.EventDecorator() | ||
def __iadd__(self, other): | ||
for us, uo in zip(self.subfunctions, other.subfunctions): | ||
us.assign(us + uo) | ||
return self | ||
|
||
@PETSc.Log.EventDecorator() | ||
def __imul__(self, other): | ||
if type(other) is type(self): | ||
for us, uo in zip(self.subfunctions, other.subfunctions): | ||
us.assign(us*uo) | ||
else: | ||
for us in self.subfunctions: | ||
us *= other | ||
return self | ||
|
||
@PETSc.Log.EventDecorator() | ||
def __add__(self, other): | ||
new = self.copy() | ||
new += other | ||
return new | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not really a problem, but it seems strange to me to define |
||
|
||
@PETSc.Log.EventDecorator() | ||
def __mul__(self, other): | ||
new = self.copy() | ||
new *= other | ||
return new | ||
|
||
@PETSc.Log.EventDecorator() | ||
def __rmul__(self, other): | ||
return self.__mul__(other) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this definitely safe to do? |
||
|
||
@contextmanager | ||
def vec(self): | ||
connorjward marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Context manager for the global PETSc Vec with read/write access. | ||
|
||
It is invalid to access the Vec outside of a context manager. | ||
""" | ||
# _full_local_function.vec shares the same storage as _vec, so we need this | ||
# nested context manager so that the data gets copied to/from | ||
# the Function.dat storage and _vec. | ||
# However, this copy is done without _vec knowing, so we have | ||
# to manually increment the state. | ||
with self._full_local_function.dat.vec: | ||
self._vec.stateIncrease() | ||
yield self._vec | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't really understand the need for this. Could you simply have def vec(self):
return self._fbuf.dat.vec or @contextlib.contextmanager
def vec(self):
yield from self._fbuf.dat.vec ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See the implementation notes I've added to the PR description. Let me know if they don't make sense and/or you think there's a better way of doing it. |
||
|
||
@contextmanager | ||
def vec_ro(self): | ||
""" | ||
Context manager for the global PETSc Vec with read only access. | ||
|
||
It is invalid to access the Vec outside of a context manager. | ||
""" | ||
# _full_local_function.vec shares the same storage as _vec, so we need this | ||
# nested context manager to make sure that the data gets copied | ||
# to the Function.dat storage and _vec. | ||
with self._full_local_function.dat.vec_ro: | ||
self._vec.stateIncrease() | ||
yield self._vec | ||
|
||
@contextmanager | ||
def vec_wo(self): | ||
""" | ||
Context manager for the global PETSc Vec with write only access. | ||
|
||
It is invalid to access the Vec outside of a context manager. | ||
""" | ||
# _full_local_function.vec shares the same storage as _vec, so we need this | ||
# nested context manager to make sure that the data gets copied | ||
# from the Function.dat storage and _vec. | ||
with self._full_local_function.dat.vec_wo: | ||
yield self._vec | ||
|
||
|
||
class EnsembleFunction(EnsembleFunctionBase): | ||
""" | ||
A mixed finite element Function distributed over an ensemble. | ||
|
||
Parameters | ||
---------- | ||
|
||
function_space : `EnsembleFunctionSpace` | ||
The function space of the function. | ||
""" | ||
def __new__(cls, function_space: EnsembleFunctionSpaceBase): | ||
if isinstance(function_space, EnsembleDualSpace): | ||
return EnsembleCofunction(function_space) | ||
return super().__new__(cls) | ||
|
||
def __init__(self, function_space: EnsembleFunctionSpace): | ||
if not isinstance(function_space, EnsembleFunctionSpace): | ||
raise TypeError( | ||
"EnsembleFunction must be created using an EnsembleFunctionSpace") | ||
super().__init__(function_space) | ||
|
||
def norm(self, *args, **kwargs): | ||
"""Compute the norm of the function. | ||
|
||
Any arguments are forwarded to `firedrake.norm`. | ||
""" | ||
return self._fs.ensemble_comm.allreduce( | ||
sum(norm(u, *args, **kwargs) for u in self.subfunctions)) | ||
|
||
|
||
class EnsembleCofunction(EnsembleFunctionBase): | ||
""" | ||
A mixed finite element Cofunction distributed over an ensemble. | ||
|
||
Parameters | ||
---------- | ||
|
||
function_space : `EnsembleDualSpace` | ||
The function space of the cofunction. | ||
""" | ||
def __init__(self, function_space: EnsembleDualSpace): | ||
if not isinstance(function_space, EnsembleDualSpace): | ||
raise TypeError( | ||
"EnsembleCofunction must be created using an EnsembleDualSpace") | ||
super().__init__(function_space) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing
__all__