Error in prediction for Segmentation configs #379

@KartikeyKansal1

Hi, I'm running python cyto_dl/eval.py experiment=im2im/segmentation.yaml ckpt_path='xyz' to run prediction with checkpoints saved during segmentation training, but I get the error below. The same command works for labelfree experiments, where it runs prediction on the complete dataset.

[2024-04-23 07:22:36,432][cyto_dl.utils.template_utils][INFO] - Closing loggers...
Error executing job with overrides: ['experiment=im2im/segmentation.yaml', 'trainer=cpu', 'experiment_name=240422_exp2_actinseg_batch_100', 'run_name=predict_run_1', 'data.batch_size=100', 'ckpt_path=/Users/kartikeykansal/Documents/tensionGAN/actin/Segmentation_240415/240422/Experiment_2/logs/train/runs/240422_exp2_actinseg_batch_100/train_run_1/2024-04-22_20-05-34/checkpoints/last.ckpt']
Traceback (most recent call last):
  File "/Users/kartikeykansal/Documents/tensionGAN/actin/Segmentation_240415/240422/Experiment_2/cyto_dl_actinseg_exp2_staging_240422/cyto_dl/eval.py", line 99, in main
    evaluate(cfg)
  File "/Users/kartikeykansal/miniconda3/envs/cytoenv240131/lib/python3.10/site-packages/cyto_dl/utils/template_utils.py", line 56, in wrap
    raise ex
  File "/Users/kartikeykansal/miniconda3/envs/cytoenv240131/lib/python3.10/site-packages/cyto_dl/utils/template_utils.py", line 53, in wrap
    out = task_func(cfg=cfg)
  File "/Users/kartikeykansal/Documents/tensionGAN/actin/Segmentation_240415/240422/Experiment_2/cyto_dl_actinseg_exp2_staging_240422/cyto_dl/eval.py", line 87, in evaluate
    output = method(model=model, dataloaders=data, ckpt_path=cfg.ckpt_path)
  File "/Users/kartikeykansal/miniconda3/envs/cytoenv240131/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 852, in predict
    return call._call_and_handle_interrupt(
  File "/Users/kartikeykansal/miniconda3/envs/cytoenv240131/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/Users/kartikeykansal/miniconda3/envs/cytoenv240131/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 894, in _predict_impl
    results = self._run(model, ckpt_path=ckpt_path)
  File "/Users/kartikeykansal/miniconda3/envs/cytoenv240131/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 946, in _run
    self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
  File "/Users/kartikeykansal/miniconda3/envs/cytoenv240131/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 400, in _restore_modules_and_callbacks
    self.restore_model()
  File "/Users/kartikeykansal/miniconda3/envs/cytoenv240131/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 280, in restore_model
    trainer.strategy.load_model_state_dict(self._loaded_checkpoint)
  File "/Users/kartikeykansal/miniconda3/envs/cytoenv240131/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 364, in load_model_state_dict
    self.lightning_module.load_state_dict(checkpoint["state_dict"])
  File "/Users/kartikeykansal/miniconda3/envs/cytoenv240131/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for MultiTaskIm2Im:
	Missing key(s) in state_dict: "task_heads.seg.loss.dice.class_weight". 
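
One possible workaround, sketched here as an assumption and not confirmed by the cyto_dl maintainers: if the missing entry is a non-learned buffer whose value comes entirely from the config (for example, class weights that a newer version of the loss code registers and that older checkpoints predate), the checkpoint can be patched by copying any missing keys from a freshly constructed model built with the same config. The helper name fill_missing_state_keys is hypothetical, and the sketch assumes the Lightning checkpoint layout where weights live under "state_dict". Plain lists stand in for tensors so the logic can be run on its own.

```python
from typing import Any, Dict, List, Tuple


def fill_missing_state_keys(
    checkpoint: Dict[str, Any],
    expected_state: Dict[str, Any],
) -> Tuple[Dict[str, Any], List[str]]:
    """Hypothetical helper: add entries the model expects but the checkpoint lacks.

    Assumes `checkpoint` stores weights under checkpoint["state_dict"] and that
    `expected_state` is the state_dict() of a freshly built model from the same
    config. Keys already in the checkpoint are never overwritten, so trained
    weights are kept as-is.
    """
    state = dict(checkpoint["state_dict"])  # shallow copy so the input is not mutated
    added = [key for key in expected_state if key not in state]
    for key in added:
        state[key] = expected_state[key]
    patched = dict(checkpoint)  # keep other checkpoint fields (epoch, optimizer, ...)
    patched["state_dict"] = state
    return patched, added


if __name__ == "__main__":
    # Toy stand-ins for real tensors: plain lists are enough to show the logic.
    old_ckpt = {"epoch": 99, "state_dict": {"backbone.weight": [0.5, -0.2]}}
    fresh_model_state = {
        "backbone.weight": [0.0, 0.0],
        "task_heads.seg.loss.dice.class_weight": [1.0, 1.0],
    }
    patched, added = fill_missing_state_keys(old_ckpt, fresh_model_state)
    print(added)  # ['task_heads.seg.loss.dice.class_weight']
    print(patched["state_dict"]["backbone.weight"])  # [0.5, -0.2]
```

With real files, the idea would be to torch.load the checkpoint with map_location="cpu" (newer PyTorch releases may also need weights_only=False for Lightning checkpoints), pass model.state_dict() from a model instantiated from the same config, torch.save the patched dict to a new path, and use that path as ckpt_path. This only makes sense if the missing keys really are config-derived buffers; if a learned parameter is missing, the checkpoint and model architecture genuinely disagree and patching would hide that.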