From d0f84a7c914cff34552c028c71962d5bc9b0cd06 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Mon, 24 Feb 2025 16:51:13 -0600 Subject: [PATCH] feat(pytorch_dlc): support the latest DLC with pytorch engine --- element_deeplabcut/model.py | 71 +++++++++++++++++++++++++++-------- element_deeplabcut/version.py | 2 +- setup.py | 3 ++ 3 files changed, 60 insertions(+), 16 deletions(-) diff --git a/element_deeplabcut/model.py b/element_deeplabcut/model.py index b1b1c51..bde24b9 100644 --- a/element_deeplabcut/model.py +++ b/element_deeplabcut/model.py @@ -308,6 +308,7 @@ class Model(dj.Manual): snapshotindex (int): Which snapshot for prediction (if -1, latest). shuffle (int): Which shuffle of the training dataset. trainingsetindex (int): Which training set fraction to generate model. + engine (str): Engine used for model. Either 'tensorflow' or 'pytorch'. scorer ( varchar(64) ): Scorer/network name - DLC's GetScorerName(). config_template (longblob): Dictionary of the config for analyze_videos(). project_path ( varchar(255) ): DLC's project_path in config relative to root. @@ -329,7 +330,8 @@ class Model(dj.Manual): snapshotindex : int # which snapshot for prediction (if -1, latest) shuffle : int # Shuffle (1) or not (0) trainingsetindex : int # Index of training fraction list in config.yaml - unique index (task, date, iteration, shuffle, snapshotindex, trainingsetindex) + engine='tensorflow' : varchar(16) # Engine used for model. Either 'tensorflow' or 'pytorch' + unique index (task, date, iteration, shuffle, snapshotindex, trainingsetindex, engine) scorer : varchar(64) # Scorer/network name - DLC's GetScorerName() config_template : longblob # Dictionary of the config for analyze_videos() project_path : varchar(255) # DLC's project_path in config relative to root @@ -378,9 +380,6 @@ def insert_new_model( prompt (bool): Optional. Prompt the user with all info before inserting. params (dict): Optional. If dlc_config is path, dict of override items """ - - from deeplabcut.utils.auxiliaryfunctions import GetScorerName # isort:skip - # handle dlc_config being a yaml file dlc_config_fp = find_full_path(get_dlc_root_data_dir(), Path(dlc_config)) assert dlc_config_fp.exists(), ( @@ -409,16 +408,37 @@ def insert_new_model( for attribute in needed_attributes: assert attribute in dlc_config, f"Couldn't find {attribute} in config" - # ---- Get scorer name ---- - # "or 'f'" below covers case where config returns None. str_to_bool handles else - scorer_legacy = str_to_bool(dlc_config.get("scorer_legacy", "f")) + engine = dlc_config.get("engine") + if engine is None: + logger.warning( + "DLC engine not specified in config file. Defaulting to TensorFlow." + ) + engine = "tensorflow" + + if engine == "tensorflow": + from deeplabcut.utils.auxiliaryfunctions import GetScorerName # isort:skip + + # ---- Get scorer name ---- + # "or 'f'" below covers case where config returns None. str_to_bool handles else + scorer_legacy = str_to_bool(dlc_config.get("scorer_legacy", "f")) + dlc_scorer = GetScorerName( + cfg=dlc_config, + shuffle=shuffle, + trainFraction=dlc_config["TrainingFraction"][int(trainingsetindex)], + modelprefix=model_prefix, + )[scorer_legacy] + elif engine == "pytorch": + from deeplabcut.pose_estimation_pytorch.apis.utils import get_scorer_name + + dlc_scorer = get_scorer_name( + cfg=dlc_config, + shuffle=shuffle, + train_fraction=dlc_config["TrainingFraction"][int(trainingsetindex)], + modelprefix=model_prefix, + ) + else: + raise ValueError(f"Unknow engine type {engine}") - dlc_scorer = GetScorerName( - cfg=dlc_config, - shuffle=shuffle, - trainFraction=dlc_config["TrainingFraction"][int(trainingsetindex)], - modelprefix=model_prefix, - )[scorer_legacy] if dlc_config["snapshotindex"] == -1: dlc_scorer = "".join(dlc_scorer.split("_")[:-1]) @@ -433,6 +453,7 @@ def insert_new_model( "snapshotindex": dlc_config["snapshotindex"], "shuffle": shuffle, "trainingsetindex": int(trainingsetindex), + "engine": engine, "project_path": project_path.relative_to(root_dir).as_posix(), "paramset_idx": paramset_idx, "config_template": dlc_config, @@ -719,7 +740,16 @@ def make(self, key): PoseEstimationTask.update1( {**key, "pose_estimation_output_dir": output_dir.as_posix()} ) - output_dir = find_full_path(get_dlc_root_data_dir(), output_dir) + + try: + output_dir = find_full_path(get_dlc_root_data_dir(), output_dir) + except FileNotFoundError as e: + if task_mode == "trigger": + processed_dir = Path(get_dlc_processed_data_dir()) + output_dir = processed_dir / output_dir + output_dir.mkdir(parents=True, exist_ok=True) + else: + raise e # Trigger PoseEstimation if task_mode == "trigger": @@ -756,7 +786,18 @@ def make(self, key): output_directory=output_dir, ) def do_analyze_videos(): - from deeplabcut.pose_estimation_tensorflow import analyze_videos + engine = dlc_model_.get("engine") + if engine is None: + logger.warning( + "DLC engine not specified in config file. Defaulting to TensorFlow." + ) + engine = "tensorflow" + if engine == "pytorch": + from deeplabcut.pose_estimation_pytorch import analyze_videos + elif engine == "tensorflow": + from deeplabcut.pose_estimation_tensorflow import analyze_videos + else: + raise ValueError(f"Unknow engine type {engine}") # ---- Build and save DLC configuration (yaml) file ---- dlc_config = dlc_model_["config_template"] diff --git a/element_deeplabcut/version.py b/element_deeplabcut/version.py index d73e036..afa728c 100644 --- a/element_deeplabcut/version.py +++ b/element_deeplabcut/version.py @@ -2,4 +2,4 @@ Package metadata """ -__version__ = "0.3.3" +__version__ = "0.4.0" diff --git a/setup.py b/setup.py index 5603760..ba455c5 100644 --- a/setup.py +++ b/setup.py @@ -46,5 +46,8 @@ "element-interface @ git+https://github.com/datajoint/element-interface.git", ], "tests": ["pytest", "pytest-cov", "shutils"], + "dlc-pytorch": [ + "deeplabcut @ git+https://github.com/DeepLabCut/DeepLabCut.git@pytorch_dlc" + ], }, )