Skip to content

Commit 0be0568

Browse files
committed
Start to move Hessian evaluation into NonlinearVariationalSolveBlock
1 parent bbb6817 commit 0be0568

File tree

2 files changed

+63
-10
lines changed

2 files changed

+63
-10
lines changed

firedrake/adjoint_utils/blocks/solving.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class Solver(Enum):
3030
FORWARD = 0
3131
ADJOINT = 1
3232
TLM = 2
33+
HESSIAN = 3
3334

3435

3536
class GenericSolveBlock(Block):
@@ -221,6 +222,9 @@ def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
221222

222223
return adj_sol, adj_sol_bdy
223224

225+
def _hessian_solve(self, *args):
226+
return self._assemble_and_solve_adj_eq(*args)
227+
224228
def _compute_adj_bdy(self, adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu):
225229
adj_sol_bdy = firedrake.assemble(dJdu - firedrake.action(dFdu_adj_form, adj_sol))
226230
return adj_sol_bdy.riesz_representation("l2")
@@ -379,8 +383,7 @@ def _assemble_and_solve_soa_eq(self, dFdu_form, adj_sol, hessian_input,
379383
b = self._assemble_soa_eq_rhs(dFdu_form, adj_sol, hessian_input,
380384
d2Fdu2)
381385
dFdu_form = firedrake.adjoint(dFdu_form)
382-
adj_sol2, adj_sol2_bdy = self._assemble_and_solve_adj_eq(dFdu_form, b,
383-
compute_bdy)
386+
adj_sol2, adj_sol2_bdy = self._hessian_solve(dFdu_form, b, compute_bdy)
384387
if self.adj2_cb is not None:
385388
self.adj2_cb(adj_sol2)
386389
if self.adj2_bdy_cb is not None and compute_bdy:
@@ -679,6 +682,22 @@ def _adjoint_solve(self, dJdu, compute_bdy):
679682
u_sol, adj_sol_bdy, jac_adj, dJdu_copy)
680683
return u_sol, adj_sol_bdy
681684

685+
def _hessian_solve(self, adj_form, rhs, compute_bdy):
686+
# self._ad_solver_replace_forms(Solver.HESSIAN)
687+
# self._ad_solvers["hessian_lvs"].invalidate_jacobian()
688+
self._ad_solvers["hessian_lvs"]._problem.F._components[1].assign(rhs)
689+
self._ad_solvers["hessian_lvs"].solve()
690+
u_sol = self._ad_solvers["hessian_lvs"]._problem.u
691+
692+
adj_sol_bdy = None
693+
if compute_bdy:
694+
jac_adj = self._ad_solvers["hessian_lvs"]._problem.J
695+
adj_sol_bdy = self._compute_adj_bdy(
696+
u_sol, adj_sol_bdy, jac_adj, rhs.copy()
697+
)
698+
699+
return u_sol, adj_sol_bdy
700+
682701
def _ad_assign_map(self, form, solver):
683702
if solver == Solver.FORWARD:
684703
count_map = self._ad_solvers["forward_nlvs"]._problem._ad_count_map
@@ -697,8 +716,10 @@ def _ad_assign_map(self, form, solver):
697716
firedrake.Cofunction)):
698717
coeff_count = coeff.count()
699718
if coeff_count in form_ad_count_map:
700-
assign_map[form_ad_count_map[coeff_count]] = \
701-
block_variable.saved_output
719+
if solver == Solver.HESSIAN:
720+
assign_map[form_ad_count_map[coeff_count]] = block_variable.tlm_value
721+
else:
722+
assign_map[form_ad_count_map[coeff_count]] = block_variable.saved_output
702723

703724
if (
704725
solver == Solver.ADJOINT
@@ -709,6 +730,7 @@ def _ad_assign_map(self, form, solver):
709730
if coeff_count in form_ad_count_map:
710731
assign_map[form_ad_count_map[coeff_count]] = \
711732
block_variable.saved_output
733+
712734
return assign_map
713735

714736
def _ad_assign_coefficients(self, form, solver):
@@ -728,6 +750,10 @@ def _ad_solver_replace_forms(self, solver=Solver.FORWARD):
728750
self._ad_assign_coefficients(
729751
self._ad_solvers["tlm_lvs"]._problem.J, solver
730752
)
753+
elif solver == Solver.HESSIAN:
754+
self._ad_assign_coefficients(
755+
self._ad_solvers["hessian_lvs"]._problem.J, solver
756+
)
731757

732758
def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
733759
compute_bdy = self._should_compute_boundary_adjoint(
@@ -851,11 +877,6 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
851877

852878
self._ad_solvers["tlm_lvs"].solve()
853879
return self._ad_solvers["tlm_lvs"]._problem.u
854-
# return self._assemble_and_solve_tlm_eq(
855-
# firedrake.assemble(dFdu, bcs=bcs, **self.assemble_kwargs),
856-
# dFdm, dudm, bcs
857-
# )
858-
859880

860881

861882
class ProjectBlock(SolveVarFormBlock):

firedrake/adjoint_utils/variational_solver.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from firedrake.adjoint_utils.blocks import NonlinearVariationalSolveBlock
55
from firedrake.ufl_expr import derivative, adjoint
66
from ufl import replace
7+
from ufl.algorithms import expand_derivatives
78

89

910
class NonlinearVariationalProblemMixin:
@@ -17,6 +18,7 @@ def wrapper(self, *args, **kwargs):
1718
self._ad_u = self.u_restrict
1819
self._ad_bcs = self.bcs
1920
self._ad_J = self.J
21+
2022
try:
2123
# Some forms (e.g. SLATE tensors) are not currently
2224
# differentiable.
@@ -27,8 +29,10 @@ def wrapper(self, *args, **kwargs):
2729
# Try again without expanding derivatives,
2830
# as dFdu might have been simplied to an empty Form
2931
self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True)
32+
3033
except (TypeError, NotImplementedError):
3134
self._ad_adj_F = None
35+
3236
self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear}
3337
self._ad_count_map = {}
3438
return wrapper
@@ -49,7 +53,8 @@ def wrapper(self, problem, *args, **kwargs):
4953
self._ad_args = args
5054
self._ad_kwargs = kwargs
5155
self._ad_solvers = {"forward_nlvs": None, "adjoint_lvs": None,
52-
"recompute_count": 0, "tlm_lvs": None}
56+
"recompute_count": 0, "tlm_lvs": None,
57+
"hessian_lvs": None}
5358
self._ad_adj_cache = {}
5459

5560
return wrapper
@@ -100,6 +105,12 @@ def wrapper(self, **kwargs):
100105
if self._ad_problem._constant_jacobian:
101106
self._ad_solvers["update_adjoint"] = False
102107

108+
if not self._ad_solvers["hessian_lvs"]:
109+
with stop_annotating():
110+
self._ad_solvers["hessian_lvs"] = LinearVariationalSolver(
111+
self._ad_hessian_lvs_problem(block, problem._ad_adj_F),
112+
)
113+
103114
if not self._ad_solvers["tlm_lvs"]:
104115
with stop_annotating():
105116
self._ad_solvers["tlm_lvs"] = LinearVariationalSolver(
@@ -168,6 +179,27 @@ def _ad_adj_lvs_problem(self, block, adj_F):
168179
lvp._ad_count_map_update(_ad_count_map)
169180
return lvp
170181

182+
@no_annotations
183+
def _ad_hessian_lvs_problem(self, block, adj_dFdu):
184+
from firedrake import Function, Cofunction, LinearVariationalProblem
185+
186+
bcs = block._homogenize_bcs()
187+
adj_sol = Function(block.function_space)
188+
right_hand_side = Cofunction(block.function_space.dual())
189+
tmp_problem = LinearVariationalProblem(
190+
adj_dFdu, right_hand_side, adj_sol, bcs=bcs,
191+
constant_jacobian=self._ad_problem._constant_jacobian)
192+
193+
_ad_count_map, J_replace_map, _ = self._build_count_map(
194+
adj_dFdu, block._dependencies,
195+
)
196+
lvp = LinearVariationalProblem(
197+
replace(tmp_problem.J, J_replace_map), right_hand_side, adj_sol,
198+
bcs=tmp_problem.bcs,
199+
constant_jacobian=self._ad_problem._constant_jacobian)
200+
lvp._ad_count_map_update(_ad_count_map)
201+
return lvp
202+
171203
@no_annotations
172204
def _ad_tlm_lvs_problem(self, block, F, u):
173205
from firedrake import Function, Cofunction, LinearVariationalProblem

0 commit comments

Comments
 (0)