@@ -30,6 +30,7 @@ class Solver(Enum):
30
30
FORWARD = 0
31
31
ADJOINT = 1
32
32
TLM = 2
33
+ HESSIAN = 3
33
34
34
35
35
36
class GenericSolveBlock (Block ):
@@ -221,6 +222,9 @@ def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
221
222
222
223
return adj_sol , adj_sol_bdy
223
224
225
+ def _hessian_solve (self , * args ):
226
+ return self ._assemble_and_solve_adj_eq (* args )
227
+
224
228
def _compute_adj_bdy (self , adj_sol , adj_sol_bdy , dFdu_adj_form , dJdu ):
225
229
adj_sol_bdy = firedrake .assemble (dJdu - firedrake .action (dFdu_adj_form , adj_sol ))
226
230
return adj_sol_bdy .riesz_representation ("l2" )
@@ -379,8 +383,7 @@ def _assemble_and_solve_soa_eq(self, dFdu_form, adj_sol, hessian_input,
379
383
b = self ._assemble_soa_eq_rhs (dFdu_form , adj_sol , hessian_input ,
380
384
d2Fdu2 )
381
385
dFdu_form = firedrake .adjoint (dFdu_form )
382
- adj_sol2 , adj_sol2_bdy = self ._assemble_and_solve_adj_eq (dFdu_form , b ,
383
- compute_bdy )
386
+ adj_sol2 , adj_sol2_bdy = self ._hessian_solve (dFdu_form , b , compute_bdy )
384
387
if self .adj2_cb is not None :
385
388
self .adj2_cb (adj_sol2 )
386
389
if self .adj2_bdy_cb is not None and compute_bdy :
@@ -679,6 +682,22 @@ def _adjoint_solve(self, dJdu, compute_bdy):
679
682
u_sol , adj_sol_bdy , jac_adj , dJdu_copy )
680
683
return u_sol , adj_sol_bdy
681
684
685
+ def _hessian_solve (self , adj_form , rhs , compute_bdy ):
686
+ # self._ad_solver_replace_forms(Solver.HESSIAN)
687
+ # self._ad_solvers["hessian_lvs"].invalidate_jacobian()
688
+ self ._ad_solvers ["hessian_lvs" ]._problem .F ._components [1 ].assign (rhs )
689
+ self ._ad_solvers ["hessian_lvs" ].solve ()
690
+ u_sol = self ._ad_solvers ["hessian_lvs" ]._problem .u
691
+
692
+ adj_sol_bdy = None
693
+ if compute_bdy :
694
+ jac_adj = self ._ad_solvers ["hessian_lvs" ]._problem .J
695
+ adj_sol_bdy = self ._compute_adj_bdy (
696
+ u_sol , adj_sol_bdy , jac_adj , rhs .copy ()
697
+ )
698
+
699
+ return u_sol , adj_sol_bdy
700
+
682
701
def _ad_assign_map (self , form , solver ):
683
702
if solver == Solver .FORWARD :
684
703
count_map = self ._ad_solvers ["forward_nlvs" ]._problem ._ad_count_map
@@ -697,8 +716,10 @@ def _ad_assign_map(self, form, solver):
697
716
firedrake .Cofunction )):
698
717
coeff_count = coeff .count ()
699
718
if coeff_count in form_ad_count_map :
700
- assign_map [form_ad_count_map [coeff_count ]] = \
701
- block_variable .saved_output
719
+ if solver == Solver .HESSIAN :
720
+ assign_map [form_ad_count_map [coeff_count ]] = block_variable .tlm_value
721
+ else :
722
+ assign_map [form_ad_count_map [coeff_count ]] = block_variable .saved_output
702
723
703
724
if (
704
725
solver == Solver .ADJOINT
@@ -709,6 +730,7 @@ def _ad_assign_map(self, form, solver):
709
730
if coeff_count in form_ad_count_map :
710
731
assign_map [form_ad_count_map [coeff_count ]] = \
711
732
block_variable .saved_output
733
+
712
734
return assign_map
713
735
714
736
def _ad_assign_coefficients (self , form , solver ):
@@ -728,6 +750,10 @@ def _ad_solver_replace_forms(self, solver=Solver.FORWARD):
728
750
self ._ad_assign_coefficients (
729
751
self ._ad_solvers ["tlm_lvs" ]._problem .J , solver
730
752
)
753
+ elif solver == Solver .HESSIAN :
754
+ self ._ad_assign_coefficients (
755
+ self ._ad_solvers ["hessian_lvs" ]._problem .J , solver
756
+ )
731
757
732
758
def prepare_evaluate_adj (self , inputs , adj_inputs , relevant_dependencies ):
733
759
compute_bdy = self ._should_compute_boundary_adjoint (
@@ -851,11 +877,6 @@ def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx,
851
877
852
878
self ._ad_solvers ["tlm_lvs" ].solve ()
853
879
return self ._ad_solvers ["tlm_lvs" ]._problem .u
854
- # return self._assemble_and_solve_tlm_eq(
855
- # firedrake.assemble(dFdu, bcs=bcs, **self.assemble_kwargs),
856
- # dFdm, dudm, bcs
857
- # )
858
-
859
880
860
881
861
882
class ProjectBlock (SolveVarFormBlock ):
0 commit comments