Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion orbit/estimators/stan_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,23 @@
import multiprocessing
from abc import abstractmethod
from copy import copy
from sys import platform, version_info

from ..exceptions import EstimatorException
from ..utils.general import update_dict
from ..utils.logger import get_logger
from ..utils.set_cmdstan_path import set_cmdstan_path
from ..utils.stan import get_compiled_stan_model
from ..utils.cmdstanpy_compat import patch_tqdm_progress_hook
from .base_estimator import BaseEstimator

logger = get_logger("orbit")

# Make sure models are using the right cmdstan folder
set_cmdstan_path()

# Apply cmdstanpy compatibility patches
patch_tqdm_progress_hook()


class StanEstimator(BaseEstimator):
"""Abstract StanEstimator with shared args for all StanEstimator child classes
Expand Down
116 changes: 116 additions & 0 deletions orbit/utils/cmdstanpy_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""
Compatibility utilities for cmdstanpy integration.

This module contains patches and workarounds for cmdstanpy compatibility issues.
"""

import os
from typing import Dict, List, Optional, Callable

from .logger import get_logger

logger = get_logger("orbit")


def patch_tqdm_progress_hook():
"""
Patch cmdstanpy progress hook to handle TQDM_DISABLE safely.

When TQDM_DISABLE=1 is set, tqdm creates disabled progress bar objects
that don't have the 'postfix' attribute. cmdstanpy assumes this attribute
exists and tries to access it, causing AttributeError.

This patch adds safe access checks to prevent the crash.

See: https://github.com/uber/orbit/issues/887
"""
# Only patch if TQDM_DISABLE is set
if os.environ.get("TQDM_DISABLE") != "1":
return

try:
import cmdstanpy.model
import re

# Store reference to original method to avoid patching multiple times
if hasattr(cmdstanpy.model.CmdStanModel, "_orbit_tqdm_patched"):
return

original_hook = getattr(
cmdstanpy.model.CmdStanModel, "_wrap_sampler_progress_hook", None
)
if original_hook is None:
return

@staticmethod
def safe_wrap_sampler_progress_hook(
chain_ids: List[int],
total: int,
) -> Optional[Callable[[str, int], None]]:
"""Safe version that handles disabled tqdm progress bars."""
try:
from tqdm import tqdm

pat = re.compile(r"Chain \[(\d*)\] (Iteration.*)")
pbars: Dict[int, tqdm] = {
chain_id: tqdm(
total=total,
bar_format="{desc} |{bar}| {elapsed} {postfix[0][value]}",
postfix=[{"value": "Status"}],
desc=f"chain {chain_id}",
colour="yellow",
)
for chain_id in chain_ids
}

def progress_hook(line: str, idx: int) -> None:
if line == "Done":
for pbar in pbars.values():
# safe postfix access
if hasattr(pbar, "postfix") and pbar.postfix:
try:
pbar.postfix[0]["value"] = "Sampling completed"
except (AttributeError, KeyError, IndexError):
pass
pbar.update(total - pbar.n)
pbar.close()
else:
match = pat.match(line)
if match:
idx = int(match.group(1))
mline = match.group(2).strip()
elif line.startswith("Iteration"):
mline = line
idx = chain_ids[idx]
else:
return

if idx in pbars:
if "Sampling" in mline and hasattr(pbars[idx], "colour"):
pbars[idx].colour = "blue"
pbars[idx].update(1)

# safe postfix access
if hasattr(pbars[idx], "postfix") and pbars[idx].postfix:
try:
pbars[idx].postfix[0]["value"] = mline
except (AttributeError, KeyError, IndexError):
pass

return progress_hook

except Exception as e:
logger.warning(
f"Progress bar setup failed: {e}. Disabling progress bars."
)
return None

# apply the patch
cmdstanpy.model.CmdStanModel._wrap_sampler_progress_hook = (
safe_wrap_sampler_progress_hook
)
cmdstanpy.model.CmdStanModel._orbit_tqdm_patched = True
logger.debug("cmdstanpy progress hook patched for TQDM_DISABLE compatibility")

except Exception as e:
logger.warning(f"Failed to patch cmdstanpy progress hook: {e}")
103 changes: 103 additions & 0 deletions tests/orbit/utils/test_cmdstanpy_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import os
import pytest
from unittest.mock import patch, MagicMock

from orbit.utils.cmdstanpy_compat import patch_tqdm_progress_hook


@pytest.mark.parametrize(
"env_value",
[None, "0", "false", "true", ""],
)
def test_patch_tqdm_progress_hook_no_patch_scenarios(env_value):
"""Test that patch is not applied when TQDM_DISABLE is not '1'."""
env_dict = {"TQDM_DISABLE": env_value} if env_value is not None else {}

with patch.dict(os.environ, env_dict, clear=False):
if env_value is None and "TQDM_DISABLE" in os.environ:
del os.environ["TQDM_DISABLE"]

# Should return early without doing anything
patch_tqdm_progress_hook()
# Test passes if no exception raised


def test_patch_tqdm_progress_hook_applies_patch():
"""Test that patch is applied when TQDM_DISABLE=1."""
with patch.dict(os.environ, {"TQDM_DISABLE": "1"}):
with patch("cmdstanpy.model") as mock_cmdstanpy_model:
mock_model = MagicMock()
mock_model._wrap_sampler_progress_hook = MagicMock()
# Ensure not already patched
del mock_model._orbit_tqdm_patched
mock_cmdstanpy_model.CmdStanModel = mock_model

patch_tqdm_progress_hook()

assert mock_model._orbit_tqdm_patched is True


def test_patch_tqdm_progress_hook_no_double_patch():
"""Test that patch is not applied multiple times."""
with patch.dict(os.environ, {"TQDM_DISABLE": "1"}):
with patch("cmdstanpy.model") as mock_cmdstanpy_model:
mock_model = MagicMock()
mock_model._orbit_tqdm_patched = True # Already patched
original_hook = MagicMock()
mock_model._wrap_sampler_progress_hook = original_hook
mock_cmdstanpy_model.CmdStanModel = mock_model

patch_tqdm_progress_hook()

# Original hook should remain unchanged
assert mock_model._wrap_sampler_progress_hook is original_hook


def test_patch_tqdm_progress_hook_handles_missing_method():
"""Test graceful handling when original method doesn't exist."""
with patch.dict(os.environ, {"TQDM_DISABLE": "1"}):
with patch("cmdstanpy.model") as mock_cmdstanpy_model:
# Simple object without the method
mock_model = type("MockModel", (), {})()
mock_cmdstanpy_model.CmdStanModel = mock_model

# Should not raise exception
patch_tqdm_progress_hook()

# Should not set patched flag
assert not hasattr(mock_model, "_orbit_tqdm_patched")


def test_patch_tqdm_progress_hook_handles_import_error():
"""Test graceful handling of import errors."""
with patch.dict(os.environ, {"TQDM_DISABLE": "1"}):
with patch("cmdstanpy.model", side_effect=ImportError("No module")):
# Should not raise exception
patch_tqdm_progress_hook()


def test_integration_with_actual_cmdstanpy():
"""Integration test with actual cmdstanpy if available."""
cmdstanpy = pytest.importorskip("cmdstanpy")

with patch.dict(os.environ, {"TQDM_DISABLE": "1"}):
# Store original state
original_method = getattr(
cmdstanpy.model.CmdStanModel, "_wrap_sampler_progress_hook", None
)

try:
patch_tqdm_progress_hook()

# Verify patch was applied
assert hasattr(cmdstanpy.model.CmdStanModel, "_orbit_tqdm_patched")
assert cmdstanpy.model.CmdStanModel._orbit_tqdm_patched is True

finally:
# Clean up
if hasattr(cmdstanpy.model.CmdStanModel, "_orbit_tqdm_patched"):
delattr(cmdstanpy.model.CmdStanModel, "_orbit_tqdm_patched")
if original_method is not None:
cmdstanpy.model.CmdStanModel._wrap_sampler_progress_hook = (
original_method
)