@@ -202,15 +202,12 @@ def __init__(self, func, idx, ad_block_tag=None):
202202 def evaluate_adj_component (self , inputs , adj_inputs , block_variable , idx ,
203203 prepared = None ):
204204 eval_adj = firedrake .Cofunction (block_variable .output .function_space ().dual ())
205- if type (adj_inputs [0 ]) is firedrake .Cofunction :
206- eval_adj .sub (self .sub_idx ).assign (adj_inputs [0 ])
207- else :
208- eval_adj .sub (self .sub_idx ).assign (adj_inputs [0 ].function )
205+ eval_adj .sub (self .sub_idx ).assign (adj_inputs [0 ])
209206 return eval_adj
210207
211208 def evaluate_tlm_component (self , inputs , tlm_inputs , block_variable , idx ,
212209 prepared = None ):
213- return firedrake . Function . sub ( tlm_inputs [0 ], self .sub_idx )
210+ return tlm_inputs [0 ]. sub ( self .sub_idx )
214211
215212 def evaluate_hessian_component (self , inputs , hessian_inputs , adj_inputs ,
216213 block_variable , idx ,
@@ -220,9 +217,7 @@ def evaluate_hessian_component(self, inputs, hessian_inputs, adj_inputs,
220217 return eval_hessian
221218
222219 def recompute_component (self , inputs , block_variable , idx , prepared ):
223- return maybe_disk_checkpoint (
224- firedrake .Function .sub (inputs [0 ], self .sub_idx )
225- )
220+ return maybe_disk_checkpoint (inputs [0 ].sub (self .sub_idx ))
226221
227222 def __str__ (self ):
228223 return f"{ self .get_dependencies ()[0 ]} [{ self .sub_idx } ]"
@@ -264,23 +259,31 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
264259 adj_inputs [0 ].subfunctions [self .sub_idx ].zero ()
265260 return adj_inputs [0 ]
266261
267- def evaluate_tlm (self , markings = False ):
268- tlm_input = self .get_dependencies ()[0 ].tlm_value
269- if tlm_input is None :
270- return
271- output = self .get_outputs ()[0 ]
272- if markings and not output .marked_in_path :
273- return
274- fs = output .output .function_space ()
275- f = type (output .output )(fs )
276- output .add_tlm_output (
277- type (output .output ).assign (f .sub (self .sub_idx ), tlm_input )
278- )
262+ def evaluate_tlm_component (self , inputs , tlm_inputs , block_variable , idx , prepared = None ):
263+ sub_tlm = tlm_inputs [0 ]
264+ parent_in = tlm_inputs [1 ]
265+
266+ if sub_tlm is None and parent_in is None :
267+ return None
268+
269+ output = self .get_outputs ()[0 ].output
270+ parent_out = type (output )(output .function_space ())
271+
272+ if parent_in is not None :
273+ parent_out .assign (parent_in )
274+ if sub_tlm is not None :
275+ parent_out .sub (self .sub_idx ).assign (sub_tlm )
276+
277+ return parent_out
279278
280279 def evaluate_hessian_component (self , inputs , hessian_inputs , adj_inputs ,
281280 block_variable , idx ,
282281 relevant_dependencies , prepared = None ):
283- return hessian_inputs [0 ]
282+ if idx == 0 :
283+ return hessian_inputs [0 ].subfunctions [self .sub_idx ].copy (deepcopy = True )
284+ else :
285+ hessian_inputs [0 ].subfunctions [self .sub_idx ].zero ()
286+ return hessian_inputs [0 ]
284287
285288 def recompute_component (self , inputs , block_variable , idx , prepared ):
286289 sub_func = inputs [0 ]
0 commit comments