@@ -264,9 +264,6 @@ def __init__(self, functional, controls, *, space_D=None, riesz_map=None, alpha=
264264 for space_D in self ._space_D )
265265 self ._controls = Enlist (Enlist (controls ).delist (self ._controls ))
266266
267- self ._alpha = alpha
268- self ._m_k = None
269-
270267 if riesz_map is None :
271268 riesz_map = tuple (map (L2RieszMap , self ._space ))
272269 self ._riesz_map = Enlist (riesz_map )
@@ -275,16 +272,19 @@ def __init__(self, functional, controls, *, space_D=None, riesz_map=None, alpha=
275272 self ._C = tuple (L2Cholesky (space_D , constant_jacobian = riesz_map .constant_jacobian )
276273 for space_D , riesz_map in zip (self ._space_D , self ._riesz_map ))
277274
275+ self ._alpha = alpha
276+ self ._m_k = None
277+
278278 # Map the initial guess
279- controls_t = self ._primal_transform (tuple (control .control for control in self ._J .controls ), apply_riesz = False )
279+ controls_t = self ._dual_transform (tuple (control .control for control in self ._J .controls ), apply_riesz = False )
280280 for control , control_t in zip (self ._controls , controls_t ):
281281 control .control .assign (control_t )
282282
283283 @property
284284 def controls (self ) -> Enlist [Control ]:
285285 return Enlist (self ._controls .delist ())
286286
287- def _primal_transform (self , u , u_D = None , * , apply_riesz = False ):
287+ def _dual_transform (self , u , u_D = None , * , apply_riesz = False ):
288288 u = Enlist (u )
289289 if len (u ) != len (self .controls ):
290290 raise ValueError ("Invalid length" )
@@ -311,7 +311,7 @@ def transform(C, u, u_D, space, space_D, riesz_map):
311311 v = tuple (map (transform , self ._C , u , u_D , self ._space , self ._space_D , self ._riesz_map ))
312312 return u .delist (v )
313313
314- def _dual_transform (self , u ):
314+ def _primal_transform (self , u ):
315315 u = Enlist (u )
316316 if len (u ) != len (self .controls ):
317317 raise ValueError ("Invalid length" )
@@ -344,17 +344,17 @@ def map_result(self, m):
344344 Returns
345345 -------
346346
347- firedrake.Function or list [firedrake.Function]
347+ firedrake.Function or Sequence [firedrake.Function]
348348 The mapped result in the original control space.
349349 """
350350
351- _ , m_J = self ._dual_transform (m )
351+ _ , m_J = self ._primal_transform (m )
352352 return m_J
353353
354354 @no_annotations
355355 def __call__ (self , values ):
356356 values = Enlist (values )
357- m_D , m_J = self ._dual_transform (values )
357+ m_D , m_J = self ._primal_transform (values )
358358 J = self ._J (m_J )
359359 if self ._alpha != 0 :
360360 for space , space_D , m_D_i , m_J_i in zip (self ._space , self ._space_D , m_D , m_J ):
@@ -381,7 +381,7 @@ def derivative(self, adj_input=1.0, apply_riesz=False):
381381 if fd .utils .complex_mode :
382382 raise RuntimeError ("Not complex differentiable" )
383383 v_alpha .append (fd .assemble (fd .Constant (self ._alpha ) * fd .inner (m_D - m_J , fd .TestFunction (space_D )) * fd .dx ))
384- v = self ._primal_transform (u , v_alpha , apply_riesz = True )
384+ v = self ._dual_transform (u , v_alpha , apply_riesz = True )
385385 if apply_riesz :
386386 v = tuple (v_i ._ad_convert_riesz (v_i , riesz_map = control .riesz_map )
387387 for v_i , control in zip (v , self .controls ))
@@ -393,7 +393,7 @@ def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=Fals
393393 raise NotImplementedError ("hessian_input not None not supported" )
394394
395395 m_dot = Enlist (m_dot )
396- m_dot_D , m_dot_J = self ._dual_transform (m_dot )
396+ m_dot_D , m_dot_J = self ._primal_transform (m_dot )
397397 u = Enlist (self ._J .hessian (m_dot .delist (m_dot_J ), evaluate_tlm = evaluate_tlm ))
398398
399399 if self ._alpha == 0 :
@@ -407,7 +407,7 @@ def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=Fals
407407 if fd .utils .complex_mode :
408408 raise RuntimeError ("Not complex differentiable" )
409409 v_alpha .append (fd .assemble (fd .Constant (self ._alpha ) * fd .inner (m_dot_D_i - m_dot_J_i , fd .TestFunction (space_D )) * fd .dx ))
410- v = self ._primal_transform (u , v_alpha , apply_riesz = True )
410+ v = self ._dual_transform (u , v_alpha , apply_riesz = True )
411411 if apply_riesz :
412412 v = tuple (v_i ._ad_convert_riesz (v_i , riesz_map = control .riesz_map )
413413 for v_i , control in zip (v , self .controls ))
@@ -416,7 +416,7 @@ def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=Fals
416416 @no_annotations
417417 def tlm (self , m_dot ):
418418 m_dot = Enlist (m_dot )
419- m_dot_D , m_dot_J = self ._dual_transform (m_dot )
419+ m_dot_D , m_dot_J = self ._primal_transform (m_dot )
420420 tau_J = self ._J .tlm (m_dot .delist (m_dot_J ))
421421
422422 if self ._alpha != 0 :
0 commit comments