@@ -72,8 +72,8 @@ def __init__(self, inputs: List[nodes.InputTensor], outputs: List[nodes.OutputTe
72
72
73
73
self ._req_input_stats = {s : self ._collect_required_stats (self ._prep , s ) for s in SCOPES }
74
74
self ._req_output_stats = {s : self ._collect_required_stats (self ._post , s ) for s in SCOPES }
75
- if any ( self ._req_output_stats [s ] for s in SCOPES ) :
76
- raise NotImplementedError ("computing statistics for output tensors not yet implemented" )
75
+ if self ._req_output_stats [DATASET ] :
76
+ raise NotImplementedError ("computing statistics for output tensors per dataset is not yet implemented" )
77
77
78
78
self ._computed_dataset_stats : Optional [Dict [str , Dict [Measure , Any ]]] = None
79
79
@@ -111,8 +111,10 @@ def apply_postprocessing(
111
111
) -> Tuple [List [xr .DataArray ], Dict [str , Dict [Measure , Any ]]]:
112
112
assert len (output_tensors ) == len (self .output_tensor_names )
113
113
tensors = dict (zip (self .output_tensor_names , output_tensors ))
114
- sample_stats = input_sample_statistics
115
- sample_stats .update (self .compute_sample_statistics (tensors , self ._req_output_stats [SAMPLE ]))
114
+ sample_stats = {
115
+ ** input_sample_statistics ,
116
+ ** self .compute_sample_statistics (tensors , self ._req_output_stats [SAMPLE ]),
117
+ }
116
118
for proc in self ._post :
117
119
proc .set_computed_sample_statistics (sample_stats )
118
120
tensors [proc .tensor_name ] = proc .apply (tensors [proc .tensor_name ])
0 commit comments