Skip to content

EnsembleFunction and friends. #4025

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 16 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions firedrake/adjoint_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
from firedrake.adjoint_utils.solving import * # noqa: F401
from firedrake.adjoint_utils.mesh import * # noqa: F401
from firedrake.adjoint_utils.checkpointing import * # noqa: F401
from firedrake.adjoint_utils.ensemble_function import * # noqa: F401
77 changes: 77 additions & 0 deletions firedrake/adjoint_utils/ensemble_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from pyadjoint.overloaded_type import OverloadedType
from firedrake.petsc import PETSc
from .checkpointing import disk_checkpointing

from functools import wraps

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing __all__


class EnsembleFunctionMixin(OverloadedType):
    """
    Basic functionality for EnsembleFunction to be OverloadedTypes.
    Note that currently no EnsembleFunction operations are taped.

    Enables EnsembleFunction to do the following:
    - Be a Control for a NumpyReducedFunctional (_ad_to_list and _ad_assign_numpy)
    - Be used with pyadjoint TAO solver (_ad_{to,from}_petsc)
    - Be used as a Control for Taylor tests (_ad_dot)
    """

    @staticmethod
    def _ad_annotate_init(init):
        """
        Decorator for the ``__init__`` of classes using this mixin.

        The wrapper initialises the :class:`OverloadedType` base, runs the
        wrapped ``init``, and then binds the ``_ad_*`` arithmetic hooks to
        the instance's own operators and ``copy`` method.
        """
        @wraps(init)
        def wrapper(self, *args, **kwargs):
            OverloadedType.__init__(self)
            init(self, *args, **kwargs)
            self._ad_add = self.__add__
            self._ad_mul = self.__mul__
            self._ad_iadd = self.__iadd__
            self._ad_imul = self.__imul__
            self._ad_copy = self.copy
        return wrapper

    @staticmethod
    def _ad_to_list(m):
        """
        Return every entry of the global Vec of ``m`` as a Python list.

        The global Vec is scattered to a sequential Vec holding all of the
        entries, which is then converted to a list.
        """
        with m.vec_ro() as gvec:
            # Scatter.toAll returns a (scatter, sequential Vec) pair: the
            # scatter context cannot be chained on directly, and the
            # sequential Vec it creates already has the global size.
            scatter, lvec = PETSc.Scatter.toAll(gvec)
            scatter.scatter(
                gvec, lvec, addv=PETSc.InsertMode.INSERT_VALUES)
        return lvec.array_r.tolist()

    @staticmethod
    def _ad_assign_numpy(dst, src, offset):
        """
        Fill ``dst`` from the flat array ``src``, starting at ``offset``.

        Only the slice of ``src`` matching the owner range of the Vec of
        ``dst`` is copied. Returns ``dst`` and the offset advanced by the
        global size of the Vec.
        """
        with dst.vec_wo() as vec:
            begin, end = vec.owner_range
            vec.array[:] = src[offset + begin: offset + end]
            offset += vec.size
        return dst, offset

    def _ad_dot(self, other, options=None):
        """
        Inner product with ``other``.

        The ``_ad_dot`` of each pair of subfunctions is summed, and the
        result is reduced over the ensemble communicator.
        """
        local_dot = sum(uself._ad_dot(uother, options=options)
                        for uself, uother in zip(self.subfunctions,
                                                 other.subfunctions))
        return self.ensemble.ensemble_comm.allreduce(local_dot)

    def _ad_convert_riesz(self, value, options=None):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def _ad_create_checkpoint(self):
        """
        Return a copy of ``self`` to use as a checkpoint.

        Raises ``NotImplementedError`` if disk checkpointing is enabled.
        """
        if disk_checkpointing():
            raise NotImplementedError(
                "Disk checkpointing not implemented for EnsembleFunctions")
        else:
            return self.copy()

    def _ad_restore_at_checkpoint(self, checkpoint):
        """
        Return ``checkpoint`` if it is exactly of the type of ``self``.

        Raises ``NotImplementedError`` for any other checkpoint type.
        """
        if type(checkpoint) is type(self):
            return checkpoint
        raise NotImplementedError(
            "Disk checkpointing not implemented for EnsembleFunctions")

    def _ad_from_petsc(self, vec):
        """Copy the values of the PETSc Vec ``vec`` into ``self``."""
        # vec_wo is a context-manager *method*: it must be called.
        with self.vec_wo() as self_v:
            vec.copy(self_v)

    def _ad_to_petsc(self, vec=None):
        """
        Copy the values of ``self`` into a PETSc Vec and return it.

        If ``vec`` is not given, a duplicate of the Vec of ``self`` is
        created to hold the result.
        """
        # vec_ro is a context-manager *method*: it must be called.
        with self.vec_ro() as self_v:
            return self_v.copy(vec or self._vec.duplicate())
3 changes: 3 additions & 0 deletions firedrake/ensemble/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from firedrake.ensemble.ensemble import * # noqa: F401
from firedrake.ensemble.ensemble_function import * # noqa: F401
from firedrake.ensemble.ensemble_functionspace import * # noqa: F401
12 changes: 12 additions & 0 deletions firedrake/ensemble.py → firedrake/ensemble/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,18 @@ def __init__(self, comm, M, **kwargs):
assert self.comm.size == M
assert self.ensemble_comm.size == (size // M)

@property
def ensemble_size(self):
    """Number of ensemble members, i.e. the size of ``ensemble_comm``."""
    comm = self.ensemble_comm
    return comm.size

@property
def ensemble_rank(self):
    """Rank of the local ensemble member within ``ensemble_comm``."""
    comm = self.ensemble_comm
    return comm.rank

def _check_function(self, f, g=None):
"""
Check if function f (and possibly a second function g) is a
Expand Down
287 changes: 287 additions & 0 deletions firedrake/ensemble/ensemble_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
from firedrake.petsc import PETSc
from firedrake.ensemble.ensemble_functionspace import (
EnsembleFunctionSpaceBase, EnsembleFunctionSpace, EnsembleDualSpace)
from firedrake.adjoint_utils import EnsembleFunctionMixin
from firedrake.function import Function
from firedrake.norms import norm
from pyop2 import MixedDat

from functools import cached_property
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to ignore (we violate this everywhere) but conventionally standard library packages go above third party ones

from contextlib import contextmanager

__all__ = ("EnsembleFunction", "EnsembleCofunction")


class EnsembleFunctionBase(EnsembleFunctionMixin):
"""
A mixed (co)function defined on a :class:`firedrake.Ensemble`.
The subcomponents are distributed over the ensemble members, and
are specified locally in a :class:`firedrake.EnsembleFunctionSpace`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what you mean by "are specified locally"

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anywhere I say "locally" I mean on each Ensemble.comm. I should either define this somewhere or be more explicit in each place I currently use "locally".


Parameters
----------

function_space : `firedrake.EnsembleFunctionSpace`.
The function space of the (co)function.

Notes
-----
Passing an `EnsembleDualSpace` to `EnsembleFunction`
will return an instance of :class:`firedrake.EnsembleCofunction`.

This class does not carry UFL symbolic information, unlike a
:class:`firedrake.Function`. UFL expressions can only be defined
locally on each ensemble member using a `firedrake.Function`
from `EnsembleFunction.subfunctions`.

See Also
--------
- Primal ensemble objects: :class:`firedrake.EnsembleFunctionSpace` and :class:`firedrake.EnsembleFunction`.
- Dual ensemble objects: :class:`firedrake.EnsembleDualSpace` and :class:`firedrake.EnsembleCofunction`.
"""

@PETSc.Log.EventDecorator()
@EnsembleFunctionMixin._ad_annotate_init
def __init__(self, function_space: EnsembleFunctionSpaceBase):
    """
    Create the local storage and the global Vec for the (co)function.

    Parameters
    ----------
    function_space
        The ensemble function space; its ``_full_local_space``,
        ``nlocal_rank_dofs``, ``nglobal_dofs`` and ``global_comm``
        are used to build the storage.
    """
    self._fs = function_space

    # we hold all subcomponents on the local
    # ensemble member in one big mixed function.
    self._full_local_function = Function(function_space._full_local_space)

    # create a Vec containing the data for all subcomponents on all
    # ensemble members. Because we use the Vec of each local mixed
    # function as the storage, if the data in the Function Vec
    # is valid then the data in the EnsembleFunction Vec is valid.

    with self._full_local_function.dat.vec as fvec:
        # (local size, global size) of the Vec over the global comm.
        n = function_space.nlocal_rank_dofs
        N = function_space.nglobal_dofs
        sizes = (n, N)
        # The new Vec is built directly on the array of the local Vec.
        self._vec = PETSc.Vec().createWithArray(
            fvec.array, size=sizes,
            comm=function_space.global_comm)
        self._vec.setFromOptions()

def function_space(self):
    """Return the ensemble function space this object was created with."""
    space = self._fs
    return space

@cached_property
def subfunctions(self):
    """
    Tuple of the (co)functions on the local ensemble member.

    Each entry is a Function on the corresponding local space, built on
    the data of the matching subcomponents of the local mixed storage.
    """
    def build(index):
        space = self._fs.local_spaces[index]
        parts = self._subcomponents(index)
        # A single subcomponent reuses its dat directly; several are
        # wrapped together in a MixedDat.
        data = (parts[0].dat if len(parts) == 1
                else MixedDat((part.dat for part in parts)))
        return Function(space, val=data)

    count = self._fs.nlocal_spaces
    return tuple(build(index) for index in range(count))

def _subcomponents(self, i):
"""
Return the subfunctions of the local mixed function storage
corresponding to the i-th local function.
"""
return tuple(self._full_local_function.subfunctions[j]
for j in self._fs._component_indices(i))

@PETSc.Log.EventDecorator()
def riesz_representation(self, riesz_map="L2", **kwargs):
    """
    Return the Riesz representation of this :class:`EnsembleFunction`
    with respect to the given Riesz map.

    The result lives in the dual of this object's space; each local
    subfunction is converted independently.

    Parameters
    ----------
    riesz_map
        The Riesz map to use (`l2`, `L2`, or `H1`). This can also be a callable.

    kwargs
        Other arguments forwarded to the ``riesz_representation`` of
        each subfunction.
    """
    result = EnsembleFunction(self._fs.dual())
    pairs = zip(self.subfunctions, result.subfunctions)
    for source, target in pairs:
        converted = source.riesz_representation(
            riesz_map=riesz_map, **kwargs)
        target.assign(converted)
    return result

@PETSc.Log.EventDecorator()
def assign(self, other, subsets=None):
    r"""Copy the values of ``other`` into this :class:`EnsembleFunction`.

    Parameters
    ----------

    other : :class:`EnsembleFunction`
        The value to assign from. Must be exactly the same type
        as ``self``.

    subsets : Collection[Optional[:class:`pyop2.types.set.Subset`]]
        One subset for each local :class:`Function`. None elements
        will be ignored. The values of each local function will
        only be assigned on the nodes on the corresponding subset.

    Returns
    -------
    ``self``, after the assignment.

    Raises
    ------
    TypeError
        If ``other`` is not exactly the same type as ``self``.
    """
    if type(other) is not type(self):
        raise TypeError(
            f"Cannot assign {type(self).__name__} from {type(other).__name__}")

    for index in range(self._fs.nlocal_spaces):
        # A falsy `subsets` means "assign everywhere".
        subset = subsets[index] if subsets else None
        self.subfunctions[index].assign(
            other.subfunctions[index], subset=subset)
    return self

@PETSc.Log.EventDecorator()
def copy(self):
    """
    Return a deep copy of the :class:`EnsembleFunction`.

    A new object of the same type is created on the same function
    space and the values of ``self`` are assigned into it.
    """
    duplicate = type(self)(self.function_space())
    duplicate.assign(self)
    return duplicate

@PETSc.Log.EventDecorator()
def zero(self, subsets=None):
    """
    Set values to zero and return ``self``.

    Parameters
    ----------

    subsets : Collection[Optional[:class:`pyop2.types.set.Subset`]]
        One subset for each local :class:`Function`. None elements
        will be ignored. The values of each local function will
        only be zeroed on the nodes on the corresponding subset.
    """
    for index in range(self._fs.nlocal_spaces):
        # A falsy `subsets` means "zero everywhere".
        subset = subsets[index] if subsets else None
        self.subfunctions[index].zero(subset=subset)
    return self

@PETSc.Log.EventDecorator()
def __iadd__(self, other):
    """Add ``other`` to ``self`` in place, subfunction by subfunction."""
    pairs = zip(self.subfunctions, other.subfunctions)
    for mine, theirs in pairs:
        mine.assign(mine + theirs)
    return self

@PETSc.Log.EventDecorator()
def __imul__(self, other):
    """
    Multiply ``self`` in place by ``other``.

    If ``other`` is exactly the same type as ``self`` the product is
    taken subfunction by subfunction; otherwise every subfunction is
    multiplied in place by ``other`` itself.
    """
    if type(other) is not type(self):
        for component in self.subfunctions:
            component *= other
        return self

    for mine, theirs in zip(self.subfunctions, other.subfunctions):
        mine.assign(mine*theirs)
    return self

@PETSc.Log.EventDecorator()
def __add__(self, other):
    """Return a new object holding ``self + other``; ``self`` is unchanged."""
    # Work on a copy so the in-place addition never touches `self`.
    total = self.copy()
    total += other
    return total
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really a problem, but it seems strange to me to define __add__ in terms of __iadd__. I would do it the other way around.


@PETSc.Log.EventDecorator()
def __mul__(self, other):
    """Return a new object holding ``self * other``; ``self`` is unchanged."""
    # Work on a copy so the in-place product never touches `self`.
    product = self.copy()
    product *= other
    return product

@PETSc.Log.EventDecorator()
def __rmul__(self, other):
    """Reflected multiplication: delegates to ``__mul__`` with the same operand."""
    # NOTE(review): this treats ``other * self`` as ``self * other``, i.e.
    # assumes the product is commutative for every supported ``other``
    # - confirm.
    return self.__mul__(other)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this definitely safe to do?


@contextmanager
def vec(self):
    """
    Context manager for the global PETSc Vec with read/write access.

    It is invalid to access the Vec outside of a context manager.
    """
    # The global Vec is built on the storage of the local mixed
    # Function, so the Function's own read/write Vec context is entered
    # first. The global Vec is not told about changes made through that
    # context, hence its state is increased by hand before it is
    # handed out.
    local_access = self._full_local_function.dat.vec
    with local_access:
        global_vec = self._vec
        global_vec.stateIncrease()
        yield self._vec
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really understand the need for this. Could you simply have

def vec(self):
    return self._fbuf.dat.vec

or

@contextlib.contextmanager
def vec(self):
    yield from self._fbuf.dat.vec

?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See the implementation notes I've added to the PR description. Let me know if they don't make sense and/or you think there's a better way of doing it.


@contextmanager
def vec_ro(self):
    """
    Context manager for the global PETSc Vec with read only access.

    It is invalid to access the Vec outside of a context manager.
    """
    # The global Vec is built on the storage of the local mixed
    # Function, so the Function's read-only Vec context is entered
    # first; the state of the global Vec is then increased by hand
    # before it is handed out.
    local_access = self._full_local_function.dat.vec_ro
    with local_access:
        global_vec = self._vec
        global_vec.stateIncrease()
        yield self._vec

@contextmanager
def vec_wo(self):
    """
    Context manager for the global PETSc Vec with write only access.

    It is invalid to access the Vec outside of a context manager.
    """
    # The global Vec is built on the storage of the local mixed
    # Function, so the Function's write-only Vec context wraps the
    # whole time the global Vec is handed out.
    local_access = self._full_local_function.dat.vec_wo
    with local_access:
        yield self._vec


class EnsembleFunction(EnsembleFunctionBase):
    """
    A mixed finite element Function distributed over an ensemble.

    Parameters
    ----------

    function_space : `EnsembleFunctionSpace`
        The function space of the function.
    """
    def __new__(cls, function_space: EnsembleFunctionSpaceBase):
        # A dual space is redirected to the cofunction class.
        if not isinstance(function_space, EnsembleDualSpace):
            return super().__new__(cls)
        return EnsembleCofunction(function_space)

    def __init__(self, function_space: EnsembleFunctionSpace):
        is_primal = isinstance(function_space, EnsembleFunctionSpace)
        if not is_primal:
            raise TypeError(
                "EnsembleFunction must be created using an EnsembleFunctionSpace")
        super().__init__(function_space)

    def norm(self, *args, **kwargs):
        """Compute the norm of the function.

        Any arguments are forwarded to `firedrake.norm`.
        """
        # NOTE(review): the per-subfunction norms are summed (locally and
        # then over the ensemble communicator), not combined as a root
        # sum of squares - confirm this is the intended definition.
        reduce_over_ensemble = self._fs.ensemble_comm.allreduce
        local_total = sum(norm(component, *args, **kwargs)
                          for component in self.subfunctions)
        return reduce_over_ensemble(local_total)


class EnsembleCofunction(EnsembleFunctionBase):
    """
    A mixed finite element Cofunction distributed over an ensemble.

    Parameters
    ----------

    function_space : `EnsembleDualSpace`
        The function space of the cofunction.
    """
    def __init__(self, function_space: EnsembleDualSpace):
        is_dual = isinstance(function_space, EnsembleDualSpace)
        if not is_dual:
            raise TypeError(
                "EnsembleCofunction must be created using an EnsembleDualSpace")
        super().__init__(function_space)
Loading
Loading