diff --git a/mmv_im2im/map_extractor.py b/mmv_im2im/map_extractor.py index b84acb0..6ee6139 100644 --- a/mmv_im2im/map_extractor.py +++ b/mmv_im2im/map_extractor.py @@ -135,6 +135,7 @@ def process_one_image( if self.pre_process is not None: x = self.pre_process(x) + x = x[0] # choose different inference function for different types of models # the input here is assumed to be a tensor @@ -257,6 +258,7 @@ def process_one_image( def run_inference(self): self.setup_model() + self.setup_data_processing() if "pred_slice2vol" in self.model_cfg.net: if self.model_cfg.net["pred_slice2vol"] is not None: