Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 56 additions & 15 deletions element_deeplabcut/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ class Model(dj.Manual):
snapshotindex (int): Which snapshot for prediction (if -1, latest).
shuffle (int): Which shuffle of the training dataset.
trainingsetindex (int): Which training set fraction to generate model.
engine (str): Engine used for model. Either 'tensorflow' or 'pytorch'.
scorer ( varchar(64) ): Scorer/network name - DLC's GetScorerName().
config_template (longblob): Dictionary of the config for analyze_videos().
project_path ( varchar(255) ): DLC's project_path in config relative to root.
Expand All @@ -329,7 +330,8 @@ class Model(dj.Manual):
snapshotindex : int # which snapshot for prediction (if -1, latest)
shuffle : int # Shuffle (1) or not (0)
trainingsetindex : int # Index of training fraction list in config.yaml
unique index (task, date, iteration, shuffle, snapshotindex, trainingsetindex)
engine='tensorflow' : varchar(16) # Engine used for model. Either 'tensorflow' or 'pytorch'
unique index (task, date, iteration, shuffle, snapshotindex, trainingsetindex, engine)
scorer : varchar(64) # Scorer/network name - DLC's GetScorerName()
config_template : longblob # Dictionary of the config for analyze_videos()
project_path : varchar(255) # DLC's project_path in config relative to root
Expand Down Expand Up @@ -378,9 +380,6 @@ def insert_new_model(
prompt (bool): Optional. Prompt the user with all info before inserting.
params (dict): Optional. If dlc_config is path, dict of override items
"""

from deeplabcut.utils.auxiliaryfunctions import GetScorerName # isort:skip

# handle dlc_config being a yaml file
dlc_config_fp = find_full_path(get_dlc_root_data_dir(), Path(dlc_config))
assert dlc_config_fp.exists(), (
Expand Down Expand Up @@ -409,16 +408,37 @@ def insert_new_model(
for attribute in needed_attributes:
assert attribute in dlc_config, f"Couldn't find {attribute} in config"

# ---- Get scorer name ----
# "or 'f'" below covers case where config returns None. str_to_bool handles else
scorer_legacy = str_to_bool(dlc_config.get("scorer_legacy", "f"))
engine = dlc_config.get("engine")
if engine is None:
logger.warning(
"DLC engine not specified in config file. Defaulting to TensorFlow."
)
engine = "tensorflow"

if engine == "tensorflow":
from deeplabcut.utils.auxiliaryfunctions import GetScorerName # isort:skip

# ---- Get scorer name ----
# "or 'f'" below covers case where config returns None. str_to_bool handles else
scorer_legacy = str_to_bool(dlc_config.get("scorer_legacy", "f"))
dlc_scorer = GetScorerName(
cfg=dlc_config,
shuffle=shuffle,
trainFraction=dlc_config["TrainingFraction"][int(trainingsetindex)],
modelprefix=model_prefix,
)[scorer_legacy]
elif engine == "pytorch":
from deeplabcut.pose_estimation_pytorch.apis.utils import get_scorer_name

dlc_scorer = get_scorer_name(
cfg=dlc_config,
shuffle=shuffle,
train_fraction=dlc_config["TrainingFraction"][int(trainingsetindex)],
modelprefix=model_prefix,
)
else:
raise ValueError(f"Unknow engine type {engine}")

dlc_scorer = GetScorerName(
cfg=dlc_config,
shuffle=shuffle,
trainFraction=dlc_config["TrainingFraction"][int(trainingsetindex)],
modelprefix=model_prefix,
)[scorer_legacy]
if dlc_config["snapshotindex"] == -1:
dlc_scorer = "".join(dlc_scorer.split("_")[:-1])

Expand All @@ -433,6 +453,7 @@ def insert_new_model(
"snapshotindex": dlc_config["snapshotindex"],
"shuffle": shuffle,
"trainingsetindex": int(trainingsetindex),
"engine": engine,
"project_path": project_path.relative_to(root_dir).as_posix(),
"paramset_idx": paramset_idx,
"config_template": dlc_config,
Expand Down Expand Up @@ -719,7 +740,16 @@ def make(self, key):
PoseEstimationTask.update1(
{**key, "pose_estimation_output_dir": output_dir.as_posix()}
)
output_dir = find_full_path(get_dlc_root_data_dir(), output_dir)

try:
output_dir = find_full_path(get_dlc_root_data_dir(), output_dir)
except FileNotFoundError as e:
if task_mode == "trigger":
processed_dir = Path(get_dlc_processed_data_dir())
output_dir = processed_dir / output_dir
output_dir.mkdir(parents=True, exist_ok=True)
else:
raise e

# Trigger PoseEstimation
if task_mode == "trigger":
Expand Down Expand Up @@ -756,7 +786,18 @@ def make(self, key):
output_directory=output_dir,
)
def do_analyze_videos():
from deeplabcut.pose_estimation_tensorflow import analyze_videos
engine = dlc_model_.get("engine")
if engine is None:
logger.warning(
"DLC engine not specified in config file. Defaulting to TensorFlow."
)
engine = "tensorflow"
if engine == "pytorch":
from deeplabcut.pose_estimation_pytorch import analyze_videos
elif engine == "tensorflow":
from deeplabcut.pose_estimation_tensorflow import analyze_videos
else:
raise ValueError(f"Unknow engine type {engine}")

# ---- Build and save DLC configuration (yaml) file ----
dlc_config = dlc_model_["config_template"]
Expand Down
2 changes: 1 addition & 1 deletion element_deeplabcut/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
Package metadata
"""

__version__ = "0.3.3"
__version__ = "0.4.0"
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,8 @@
"element-interface @ git+https://github.com/datajoint/element-interface.git",
],
"tests": ["pytest", "pytest-cov", "shutils"],
"dlc-pytorch": [
"deeplabcut @ git+https://github.com/DeepLabCut/DeepLabCut.git@pytorch_dlc"
],
},
)