Skip to content

Commit

Permalink
some review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Jun 28, 2024
1 parent 5aec470 commit ac6401c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
1 change: 1 addition & 0 deletions main_SGD.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(self, data: Dataset, num_subsets: int = 7, step_size: float = 1e-10
"""
Initialisation function, setting up data & (hyper)parameters.
NB: in practice, `num_subsets` should likely be determined from the data.
WARNING: we also currently ignore the non-negativity constraint here.
This is just an example. Try to modify and improve it!
"""
data_sub, acq_models, obj_funs = partitioner.data_partition(data.acquired_data, data.additive_term,
Expand Down
8 changes: 5 additions & 3 deletions petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,21 @@ def __call__(self, algo: Algorithm):

class MetricsWithTimeout(cbks.Callback):
"""Stops the algorithm after `seconds`"""
def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, ground_truth=None,
def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, reference_image=None,
verbose=1):
super().__init__(verbose)
self.callbacks = [
cbks.ProgressCallback(),
SaveIters(outdir=outdir),
(tb_cbk := TensorBoard(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice))]

if ground_truth:
if reference_image:
roi_image_dict = {f'S{i}': STIR.ImageData(f'S{i}.hv') for i in range(1, 8)}
# NB: these metrics are for testing only.
# The final evaluation will use metrics described in https://github.com/SyneRBI/PETRIC/wiki
self.callbacks.append(
ImageQualityCallback(
ground_truth, tb_cbk.tb, roi_mask_dict=roi_image_dict, metrics_dict={
reference_image, tb_cbk.tb, roi_mask_dict=roi_image_dict, metrics_dict={
'MSE': mean_squared_error, 'MAE': self.mean_absolute_error, 'PSNR': peak_signal_noise_ratio},
statistics_dict={'MEAN': np.mean, 'STDDEV': np.std, 'MAX': np.max}))

Expand Down

0 comments on commit ac6401c

Please sign in to comment.