Keys not found during inference #136

Open
fronos opened this issue Dec 10, 2024 · 1 comment


fronos commented Dec 10, 2024

Describe the bug

I trained a custom YOLOv9-m model from scratch. During inference I noticed the logger throwing these warnings:

WARNING  ⚠ Weight Not Found for key: 19.conv                   yolo.py:152
WARNING  ⚠ Weight Not Found for key: 21.conv3.0.conv2          yolo.py:152
WARNING  ⚠ Weight Not Found for key: 34.conv3.0.conv2          yolo.py:152
WARNING  ⚠ Weight Not Found for key: 38.heads.1.anchor_conv.0  yolo.py:152
WARNING  ⚠ Weight Not Found for key: 22.heads.2.anc2vec

The main reasons for these warnings:

  1. Lightning saves the weights under a key (state_dict) inside a checkpoint file with the .ckpt extension.
  2. The keys of the initialized model and the trained checkpoint are mismatched: the checkpoint keys carry an extra model.model. prefix (see the sketch below).
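
A quick way to confirm the mismatch is to print the first few checkpoint keys (a minimal sketch; the checkpoint path is a placeholder):

import torch

# Lightning stores the weights under "state_dict"; the keys carry the
# attribute path of the LightningModule, e.g. "model.model.0.conv.weight",
# while a freshly built YOLO model expects "model.0.conv.weight".
checkpoint = torch.load("lightning_logs/version_0/checkpoints/last.ckpt", map_location="cpu")
for key in list(checkpoint["state_dict"])[:5]:
    print(key)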

System Info:

  • OS: [e.g. Ubuntu 20.04]
  • Python Version: 3.10.16
  • PyTorch Version: 2.5.1
  • Lightning: 2.4.0
  • CUDA/cuDNN/MPS Version: 12.4
  • YOLO Model Version: YOLOv9-m
fronos added the bug label on Dec 10, 2024
fronos changed the title from "Missed keys during inference" to "Keys not found during inference" on Dec 10, 2024

cansik commented Dec 11, 2024

I have the same issue here with a custom YOLOv9-t model: the custom-trained checkpoint can't be loaded for inference. What I've done is adjust the create_model method in yolo.py so it can read .ckpt files. This seems to work!

def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80) -> YOLO:
    """Constructs and returns a model from a dictionary configuration file.

    Args:
        model_cfg (ModelConfig): The configuration of the model.
        weight_path (Union[bool, Path]): True to use the default weight file, or a path to a .pt/.ckpt file.
        class_num (int): Number of classes the model predicts.

    Returns:
        YOLO: An instance of the model defined by the given configuration.
    """
    OmegaConf.set_struct(model_cfg, False)
    model = YOLO(model_cfg, class_num)
    if weight_path:
        if weight_path is True:
            weight_path = Path("weights") / f"{model_cfg.name}.pt"
        elif isinstance(weight_path, str):
            weight_path = Path(weight_path)

        if not weight_path.exists():
            logger.info(f"🌐 Weight {weight_path} not found, try downloading")
            prepare_weight(weight_path=weight_path)
        if weight_path.exists():
            if weight_path.suffix == ".ckpt":
                checkpoint = torch.load(weight_path, map_location="cpu")
                state_dict = checkpoint["state_dict"]

                # Lightning prefixes the keys with the attribute path of the
                # LightningModule ("model.model."); strip the first "model."
                # so the keys match the bare YOLO module.
                new_state_dict = {}
                for key, value in state_dict.items():
                    if key.startswith("model.model."):
                        new_key = key.replace("model.model.", "model.", 1)
                        new_state_dict[new_key] = value
                    else:
                        new_state_dict[key] = value

                # Load the remapped state dict; strict=False tolerates keys
                # that are still missing (they are reported below).
                model.load_state_dict(new_state_dict, strict=False)

                for name, _ in model.named_parameters():
                    if name not in new_state_dict:
                        print(f"Missing weight for layer: {name}")

                logger.info(":white_check_mark: Success load model & checkpoint")
            else:
                model.save_load_weights(weight_path)
                logger.info(":white_check_mark: Success load model & weight")
    return model
