Skip to content

Commit 751bca1

Browse files
committed
Tidying
1 parent 90903a7 commit 751bca1

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

firedrake/adjoint/transformed_functional.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,6 @@ def __init__(self, functional, controls, *, space_D=None, riesz_map=None, alpha=
264264
for space_D in self._space_D)
265265
self._controls = Enlist(Enlist(controls).delist(self._controls))
266266

267-
self._alpha = alpha
268-
self._m_k = None
269-
270267
if riesz_map is None:
271268
riesz_map = tuple(map(L2RieszMap, self._space))
272269
self._riesz_map = Enlist(riesz_map)
@@ -275,16 +272,19 @@ def __init__(self, functional, controls, *, space_D=None, riesz_map=None, alpha=
275272
self._C = tuple(L2Cholesky(space_D, constant_jacobian=riesz_map.constant_jacobian)
276273
for space_D, riesz_map in zip(self._space_D, self._riesz_map))
277274

275+
self._alpha = alpha
276+
self._m_k = None
277+
278278
# Map the initial guess
279-
controls_t = self._primal_transform(tuple(control.control for control in self._J.controls), apply_riesz=False)
279+
controls_t = self._dual_transform(tuple(control.control for control in self._J.controls), apply_riesz=False)
280280
for control, control_t in zip(self._controls, controls_t):
281281
control.control.assign(control_t)
282282

283283
@property
284284
def controls(self) -> Enlist[Control]:
285285
return Enlist(self._controls.delist())
286286

287-
def _primal_transform(self, u, u_D=None, *, apply_riesz=False):
287+
def _dual_transform(self, u, u_D=None, *, apply_riesz=False):
288288
u = Enlist(u)
289289
if len(u) != len(self.controls):
290290
raise ValueError("Invalid length")
@@ -311,7 +311,7 @@ def transform(C, u, u_D, space, space_D, riesz_map):
311311
v = tuple(map(transform, self._C, u, u_D, self._space, self._space_D, self._riesz_map))
312312
return u.delist(v)
313313

314-
def _dual_transform(self, u):
314+
def _primal_transform(self, u):
315315
u = Enlist(u)
316316
if len(u) != len(self.controls):
317317
raise ValueError("Invalid length")
@@ -344,17 +344,17 @@ def map_result(self, m):
344344
Returns
345345
-------
346346
347-
firedrake.Function or list[firedrake.Function]
347+
firedrake.Function or Sequence[firedrake.Function]
348348
The mapped result in the original control space.
349349
"""
350350

351-
_, m_J = self._dual_transform(m)
351+
_, m_J = self._primal_transform(m)
352352
return m_J
353353

354354
@no_annotations
355355
def __call__(self, values):
356356
values = Enlist(values)
357-
m_D, m_J = self._dual_transform(values)
357+
m_D, m_J = self._primal_transform(values)
358358
J = self._J(m_J)
359359
if self._alpha != 0:
360360
for space, space_D, m_D_i, m_J_i in zip(self._space, self._space_D, m_D, m_J):
@@ -381,7 +381,7 @@ def derivative(self, adj_input=1.0, apply_riesz=False):
381381
if fd.utils.complex_mode:
382382
raise RuntimeError("Not complex differentiable")
383383
v_alpha.append(fd.assemble(fd.Constant(self._alpha) * fd.inner(m_D - m_J, fd.TestFunction(space_D)) * fd.dx))
384-
v = self._primal_transform(u, v_alpha, apply_riesz=True)
384+
v = self._dual_transform(u, v_alpha, apply_riesz=True)
385385
if apply_riesz:
386386
v = tuple(v_i._ad_convert_riesz(v_i, riesz_map=control.riesz_map)
387387
for v_i, control in zip(v, self.controls))
@@ -393,7 +393,7 @@ def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=Fals
393393
raise NotImplementedError("hessian_input not None not supported")
394394

395395
m_dot = Enlist(m_dot)
396-
m_dot_D, m_dot_J = self._dual_transform(m_dot)
396+
m_dot_D, m_dot_J = self._primal_transform(m_dot)
397397
u = Enlist(self._J.hessian(m_dot.delist(m_dot_J), evaluate_tlm=evaluate_tlm))
398398

399399
if self._alpha == 0:
@@ -407,7 +407,7 @@ def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=Fals
407407
if fd.utils.complex_mode:
408408
raise RuntimeError("Not complex differentiable")
409409
v_alpha.append(fd.assemble(fd.Constant(self._alpha) * fd.inner(m_dot_D_i - m_dot_J_i, fd.TestFunction(space_D)) * fd.dx))
410-
v = self._primal_transform(u, v_alpha, apply_riesz=True)
410+
v = self._dual_transform(u, v_alpha, apply_riesz=True)
411411
if apply_riesz:
412412
v = tuple(v_i._ad_convert_riesz(v_i, riesz_map=control.riesz_map)
413413
for v_i, control in zip(v, self.controls))
@@ -416,7 +416,7 @@ def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=Fals
416416
@no_annotations
417417
def tlm(self, m_dot):
418418
m_dot = Enlist(m_dot)
419-
m_dot_D, m_dot_J = self._dual_transform(m_dot)
419+
m_dot_D, m_dot_J = self._primal_transform(m_dot)
420420
tau_J = self._J.tlm(m_dot.delist(m_dot_J))
421421

422422
if self._alpha != 0:

tests/firedrake/adjoint/test_transformed_functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def error_norm(m):
247247
error_norm_opt = error_norm(m_opt)
248248
print(f"{error_norm_opt=:.6g}")
249249
assert 1e-2 < error_norm_opt < 5e-2
250-
assert J_hat._test_transformed_functional__ncalls > 22 # == 25
250+
assert J_hat._test_transformed_functional__ncalls > 22 # == 24
251251

252252
J_hat = L2TransformedFunctional(J, c, alpha=1e-5)
253253

0 commit comments

Comments
 (0)